agentic-team-templates 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +280 -0
- package/bin/cli.js +5 -0
- package/package.json +47 -0
- package/src/index.js +521 -0
- package/templates/_shared/code-quality.md +162 -0
- package/templates/_shared/communication.md +114 -0
- package/templates/_shared/core-principles.md +62 -0
- package/templates/_shared/git-workflow.md +165 -0
- package/templates/_shared/security-fundamentals.md +173 -0
- package/templates/blockchain/.cursorrules/defi-patterns.md +520 -0
- package/templates/blockchain/.cursorrules/gas-optimization.md +339 -0
- package/templates/blockchain/.cursorrules/overview.md +130 -0
- package/templates/blockchain/.cursorrules/security.md +318 -0
- package/templates/blockchain/.cursorrules/smart-contracts.md +364 -0
- package/templates/blockchain/.cursorrules/testing.md +415 -0
- package/templates/blockchain/.cursorrules/web3-integration.md +538 -0
- package/templates/blockchain/CLAUDE.md +389 -0
- package/templates/cli-tools/.cursorrules/architecture.md +412 -0
- package/templates/cli-tools/.cursorrules/arguments.md +406 -0
- package/templates/cli-tools/.cursorrules/distribution.md +546 -0
- package/templates/cli-tools/.cursorrules/error-handling.md +455 -0
- package/templates/cli-tools/.cursorrules/overview.md +136 -0
- package/templates/cli-tools/.cursorrules/testing.md +537 -0
- package/templates/cli-tools/.cursorrules/user-experience.md +545 -0
- package/templates/cli-tools/CLAUDE.md +356 -0
- package/templates/data-engineering/.cursorrules/data-modeling.md +367 -0
- package/templates/data-engineering/.cursorrules/data-quality.md +455 -0
- package/templates/data-engineering/.cursorrules/overview.md +85 -0
- package/templates/data-engineering/.cursorrules/performance.md +339 -0
- package/templates/data-engineering/.cursorrules/pipeline-design.md +280 -0
- package/templates/data-engineering/.cursorrules/security.md +460 -0
- package/templates/data-engineering/.cursorrules/testing.md +452 -0
- package/templates/data-engineering/CLAUDE.md +974 -0
- package/templates/devops-sre/.cursorrules/capacity-planning.md +653 -0
- package/templates/devops-sre/.cursorrules/change-management.md +584 -0
- package/templates/devops-sre/.cursorrules/chaos-engineering.md +651 -0
- package/templates/devops-sre/.cursorrules/disaster-recovery.md +641 -0
- package/templates/devops-sre/.cursorrules/incident-management.md +565 -0
- package/templates/devops-sre/.cursorrules/observability.md +714 -0
- package/templates/devops-sre/.cursorrules/overview.md +230 -0
- package/templates/devops-sre/.cursorrules/postmortems.md +588 -0
- package/templates/devops-sre/.cursorrules/runbooks.md +760 -0
- package/templates/devops-sre/.cursorrules/slo-sli.md +617 -0
- package/templates/devops-sre/.cursorrules/toil-reduction.md +567 -0
- package/templates/devops-sre/CLAUDE.md +1007 -0
- package/templates/documentation/.cursorrules/adr.md +277 -0
- package/templates/documentation/.cursorrules/api-documentation.md +411 -0
- package/templates/documentation/.cursorrules/code-comments.md +253 -0
- package/templates/documentation/.cursorrules/maintenance.md +260 -0
- package/templates/documentation/.cursorrules/overview.md +82 -0
- package/templates/documentation/.cursorrules/readme-standards.md +306 -0
- package/templates/documentation/CLAUDE.md +120 -0
- package/templates/fullstack/.cursorrules/api-contracts.md +331 -0
- package/templates/fullstack/.cursorrules/architecture.md +298 -0
- package/templates/fullstack/.cursorrules/overview.md +109 -0
- package/templates/fullstack/.cursorrules/shared-types.md +348 -0
- package/templates/fullstack/.cursorrules/testing.md +386 -0
- package/templates/fullstack/CLAUDE.md +349 -0
- package/templates/ml-ai/.cursorrules/data-engineering.md +483 -0
- package/templates/ml-ai/.cursorrules/deployment.md +601 -0
- package/templates/ml-ai/.cursorrules/model-development.md +538 -0
- package/templates/ml-ai/.cursorrules/monitoring.md +658 -0
- package/templates/ml-ai/.cursorrules/overview.md +131 -0
- package/templates/ml-ai/.cursorrules/security.md +637 -0
- package/templates/ml-ai/.cursorrules/testing.md +678 -0
- package/templates/ml-ai/CLAUDE.md +1136 -0
- package/templates/mobile/.cursorrules/navigation.md +246 -0
- package/templates/mobile/.cursorrules/offline-first.md +302 -0
- package/templates/mobile/.cursorrules/overview.md +71 -0
- package/templates/mobile/.cursorrules/performance.md +345 -0
- package/templates/mobile/.cursorrules/testing.md +339 -0
- package/templates/mobile/CLAUDE.md +233 -0
- package/templates/platform-engineering/.cursorrules/ci-cd.md +778 -0
- package/templates/platform-engineering/.cursorrules/developer-experience.md +632 -0
- package/templates/platform-engineering/.cursorrules/infrastructure-as-code.md +600 -0
- package/templates/platform-engineering/.cursorrules/kubernetes.md +710 -0
- package/templates/platform-engineering/.cursorrules/observability.md +747 -0
- package/templates/platform-engineering/.cursorrules/overview.md +215 -0
- package/templates/platform-engineering/.cursorrules/security.md +855 -0
- package/templates/platform-engineering/.cursorrules/testing.md +878 -0
- package/templates/platform-engineering/CLAUDE.md +850 -0
- package/templates/utility-agent/.cursorrules/action-control.md +284 -0
- package/templates/utility-agent/.cursorrules/context-management.md +186 -0
- package/templates/utility-agent/.cursorrules/hallucination-prevention.md +253 -0
- package/templates/utility-agent/.cursorrules/overview.md +78 -0
- package/templates/utility-agent/.cursorrules/token-optimization.md +369 -0
- package/templates/utility-agent/CLAUDE.md +513 -0
- package/templates/web-backend/.cursorrules/api-design.md +255 -0
- package/templates/web-backend/.cursorrules/authentication.md +309 -0
- package/templates/web-backend/.cursorrules/database-patterns.md +298 -0
- package/templates/web-backend/.cursorrules/error-handling.md +366 -0
- package/templates/web-backend/.cursorrules/overview.md +69 -0
- package/templates/web-backend/.cursorrules/security.md +358 -0
- package/templates/web-backend/.cursorrules/testing.md +395 -0
- package/templates/web-backend/CLAUDE.md +366 -0
- package/templates/web-frontend/.cursorrules/accessibility.md +296 -0
- package/templates/web-frontend/.cursorrules/component-patterns.md +204 -0
- package/templates/web-frontend/.cursorrules/overview.md +72 -0
- package/templates/web-frontend/.cursorrules/performance.md +325 -0
- package/templates/web-frontend/.cursorrules/state-management.md +227 -0
- package/templates/web-frontend/.cursorrules/styling.md +271 -0
- package/templates/web-frontend/.cursorrules/testing.md +311 -0
- package/templates/web-frontend/CLAUDE.md +399 -0
|
@@ -0,0 +1,538 @@
# Model Development

Guidelines for training, experimentation, evaluation, and hyperparameter optimization in machine learning projects.

## Experiment Tracking

### MLflow Integration

```python
import mlflow
|
|
11
|
+
from mlflow.tracking import MlflowClient
|
|
12
|
+
|
|
13
|
+
def train_with_tracking(config: TrainingConfig, train_data, val_data) -> str:
|
|
14
|
+
"""Train model with comprehensive experiment tracking."""
|
|
15
|
+
|
|
16
|
+
mlflow.set_experiment(config.experiment_name)
|
|
17
|
+
|
|
18
|
+
with mlflow.start_run(run_name=config.run_name) as run:
|
|
19
|
+
# Log configuration
|
|
20
|
+
mlflow.log_params({
|
|
21
|
+
"model_type": config.model_type,
|
|
22
|
+
"learning_rate": config.learning_rate,
|
|
23
|
+
"batch_size": config.batch_size,
|
|
24
|
+
"epochs": config.epochs,
|
|
25
|
+
"optimizer": config.optimizer,
|
|
26
|
+
"loss_function": config.loss_function,
|
|
27
|
+
})
|
|
28
|
+
|
|
29
|
+
# Log data information
|
|
30
|
+
mlflow.log_params({
|
|
31
|
+
"train_samples": len(train_data),
|
|
32
|
+
"val_samples": len(val_data),
|
|
33
|
+
"feature_count": train_data.shape[1],
|
|
34
|
+
"label_distribution": dict(train_data["label"].value_counts()),
|
|
35
|
+
})
|
|
36
|
+
|
|
37
|
+
# Log environment
|
|
38
|
+
mlflow.log_params({
|
|
39
|
+
"python_version": sys.version,
|
|
40
|
+
"torch_version": torch.__version__,
|
|
41
|
+
"cuda_available": torch.cuda.is_available(),
|
|
42
|
+
})
|
|
43
|
+
|
|
44
|
+
# Train with metric logging
|
|
45
|
+
model = create_model(config)
|
|
46
|
+
for epoch in range(config.epochs):
|
|
47
|
+
train_loss = train_epoch(model, train_data)
|
|
48
|
+
val_metrics = evaluate(model, val_data)
|
|
49
|
+
|
|
50
|
+
mlflow.log_metrics({
|
|
51
|
+
"train_loss": train_loss,
|
|
52
|
+
"val_loss": val_metrics["loss"],
|
|
53
|
+
"val_accuracy": val_metrics["accuracy"],
|
|
54
|
+
"val_f1": val_metrics["f1"],
|
|
55
|
+
}, step=epoch)
|
|
56
|
+
|
|
57
|
+
# Log final model with signature
|
|
58
|
+
signature = mlflow.models.infer_signature(
|
|
59
|
+
train_data.drop("label", axis=1).head(),
|
|
60
|
+
model.predict(train_data.drop("label", axis=1).head())
|
|
61
|
+
)
|
|
62
|
+
mlflow.pytorch.log_model(model, "model", signature=signature)
|
|
63
|
+
|
|
64
|
+
# Log artifacts
|
|
65
|
+
mlflow.log_artifact("configs/training_config.yaml")
|
|
66
|
+
|
|
67
|
+
# Log custom artifacts
|
|
68
|
+
fig = plot_confusion_matrix(model, val_data)
|
|
69
|
+
mlflow.log_figure(fig, "confusion_matrix.png")
|
|
70
|
+
|
|
71
|
+
return run.info.run_id
```

### Experiment Organization

```python
# Hierarchical experiment structure
|
|
78
|
+
mlflow.set_experiment("fraud-detection/v2/feature-experiments")
|
|
79
|
+
|
|
80
|
+
# Tagging for filtering
|
|
81
|
+
with mlflow.start_run() as run:
|
|
82
|
+
mlflow.set_tags({
|
|
83
|
+
"team": "ml-platform",
|
|
84
|
+
"model_family": "gradient_boosting",
|
|
85
|
+
"data_version": "2024-01-15",
|
|
86
|
+
"experiment_type": "hyperparameter_search",
|
|
87
|
+
})
|
|
88
|
+
|
|
89
|
+
# Query experiments
|
|
90
|
+
client = MlflowClient()
|
|
91
|
+
runs = client.search_runs(
|
|
92
|
+
experiment_ids=["1"],
|
|
93
|
+
filter_string="metrics.val_f1 > 0.85 AND params.model_type = 'xgboost'",
|
|
94
|
+
order_by=["metrics.val_f1 DESC"],
|
|
95
|
+
max_results=10,
|
|
96
|
+
)
```

## Evaluation Metrics

### Classification Metrics

```python
from dataclasses import dataclass
|
|
105
|
+
from sklearn.metrics import (
|
|
106
|
+
accuracy_score, precision_recall_fscore_support,
|
|
107
|
+
roc_auc_score, average_precision_score,
|
|
108
|
+
confusion_matrix, classification_report
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
@dataclass
|
|
112
|
+
class ClassificationMetrics:
|
|
113
|
+
"""Comprehensive classification evaluation."""
|
|
114
|
+
accuracy: float
|
|
115
|
+
precision: float
|
|
116
|
+
recall: float
|
|
117
|
+
f1: float
|
|
118
|
+
roc_auc: float
|
|
119
|
+
pr_auc: float
|
|
120
|
+
confusion_matrix: np.ndarray
|
|
121
|
+
|
|
122
|
+
@classmethod
|
|
123
|
+
def compute(
|
|
124
|
+
cls,
|
|
125
|
+
y_true: np.ndarray,
|
|
126
|
+
y_pred: np.ndarray,
|
|
127
|
+
y_prob: np.ndarray
|
|
128
|
+
) -> "ClassificationMetrics":
|
|
129
|
+
precision, recall, f1, _ = precision_recall_fscore_support(
|
|
130
|
+
y_true, y_pred, average="binary"
|
|
131
|
+
)
|
|
132
|
+
return cls(
|
|
133
|
+
accuracy=accuracy_score(y_true, y_pred),
|
|
134
|
+
precision=precision,
|
|
135
|
+
recall=recall,
|
|
136
|
+
f1=f1,
|
|
137
|
+
roc_auc=roc_auc_score(y_true, y_prob),
|
|
138
|
+
pr_auc=average_precision_score(y_true, y_prob),
|
|
139
|
+
confusion_matrix=confusion_matrix(y_true, y_pred),
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
def to_dict(self) -> dict[str, float]:
|
|
143
|
+
return {
|
|
144
|
+
"accuracy": self.accuracy,
|
|
145
|
+
"precision": self.precision,
|
|
146
|
+
"recall": self.recall,
|
|
147
|
+
"f1": self.f1,
|
|
148
|
+
"roc_auc": self.roc_auc,
|
|
149
|
+
"pr_auc": self.pr_auc,
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
def summary(self) -> str:
|
|
153
|
+
return f"""
|
|
154
|
+
Accuracy: {self.accuracy:.4f}
|
|
155
|
+
Precision: {self.precision:.4f}
|
|
156
|
+
Recall: {self.recall:.4f}
|
|
157
|
+
F1 Score: {self.f1:.4f}
|
|
158
|
+
ROC AUC: {self.roc_auc:.4f}
|
|
159
|
+
PR AUC: {self.pr_auc:.4f}
|
|
160
|
+
"""
```

### Regression Metrics

```python
@dataclass
|
|
167
|
+
class RegressionMetrics:
|
|
168
|
+
"""Comprehensive regression evaluation."""
|
|
169
|
+
mse: float
|
|
170
|
+
rmse: float
|
|
171
|
+
mae: float
|
|
172
|
+
mape: float
|
|
173
|
+
r2: float
|
|
174
|
+
explained_variance: float
|
|
175
|
+
|
|
176
|
+
@classmethod
|
|
177
|
+
def compute(cls, y_true: np.ndarray, y_pred: np.ndarray) -> "RegressionMetrics":
|
|
178
|
+
mse = mean_squared_error(y_true, y_pred)
|
|
179
|
+
return cls(
|
|
180
|
+
mse=mse,
|
|
181
|
+
rmse=np.sqrt(mse),
|
|
182
|
+
mae=mean_absolute_error(y_true, y_pred),
|
|
183
|
+
mape=mean_absolute_percentage_error(y_true, y_pred),
|
|
184
|
+
r2=r2_score(y_true, y_pred),
|
|
185
|
+
explained_variance=explained_variance_score(y_true, y_pred),
|
|
186
|
+
)
```

### Business Metrics

```python
def compute_business_metrics(
|
|
193
|
+
y_true: np.ndarray,
|
|
194
|
+
y_pred: np.ndarray,
|
|
195
|
+
amounts: np.ndarray,
|
|
196
|
+
cost_matrix: dict
|
|
197
|
+
) -> dict:
|
|
198
|
+
"""Compute business-relevant metrics."""
|
|
199
|
+
|
|
200
|
+
# Cost-sensitive evaluation
|
|
201
|
+
tp_mask = (y_true == 1) & (y_pred == 1)
|
|
202
|
+
fp_mask = (y_true == 0) & (y_pred == 1)
|
|
203
|
+
fn_mask = (y_true == 1) & (y_pred == 0)
|
|
204
|
+
tn_mask = (y_true == 0) & (y_pred == 0)
|
|
205
|
+
|
|
206
|
+
# Fraud detection example
|
|
207
|
+
fraud_caught = amounts[tp_mask].sum()
|
|
208
|
+
fraud_missed = amounts[fn_mask].sum()
|
|
209
|
+
false_alarm_cost = len(amounts[fp_mask]) * cost_matrix["investigation_cost"]
|
|
210
|
+
|
|
211
|
+
return {
|
|
212
|
+
"fraud_caught_amount": fraud_caught,
|
|
213
|
+
"fraud_missed_amount": fraud_missed,
|
|
214
|
+
"false_alarm_cost": false_alarm_cost,
|
|
215
|
+
"net_savings": fraud_caught - false_alarm_cost,
|
|
216
|
+
"precision_at_k": precision_at_k(y_true, y_pred, k=100),
|
|
217
|
+
"lift_at_k": lift_at_k(y_true, y_pred, k=100),
|
|
218
|
+
}
```

### Threshold Optimization

```python
def optimize_threshold(
|
|
225
|
+
y_true: np.ndarray,
|
|
226
|
+
y_prob: np.ndarray,
|
|
227
|
+
metric: str = "f1",
|
|
228
|
+
constraints: dict = None
|
|
229
|
+
) -> float:
|
|
230
|
+
"""Find optimal classification threshold."""
|
|
231
|
+
|
|
232
|
+
thresholds = np.linspace(0.01, 0.99, 99)
|
|
233
|
+
best_threshold = 0.5
|
|
234
|
+
best_score = 0
|
|
235
|
+
|
|
236
|
+
for threshold in thresholds:
|
|
237
|
+
y_pred = (y_prob >= threshold).astype(int)
|
|
238
|
+
|
|
239
|
+
# Check constraints
|
|
240
|
+
if constraints:
|
|
241
|
+
precision = precision_score(y_true, y_pred)
|
|
242
|
+
recall = recall_score(y_true, y_pred)
|
|
243
|
+
|
|
244
|
+
if "min_precision" in constraints and precision < constraints["min_precision"]:
|
|
245
|
+
continue
|
|
246
|
+
if "min_recall" in constraints and recall < constraints["min_recall"]:
|
|
247
|
+
continue
|
|
248
|
+
|
|
249
|
+
# Compute target metric
|
|
250
|
+
if metric == "f1":
|
|
251
|
+
score = f1_score(y_true, y_pred)
|
|
252
|
+
elif metric == "precision":
|
|
253
|
+
score = precision_score(y_true, y_pred)
|
|
254
|
+
elif metric == "recall":
|
|
255
|
+
score = recall_score(y_true, y_pred)
|
|
256
|
+
|
|
257
|
+
if score > best_score:
|
|
258
|
+
best_score = score
|
|
259
|
+
best_threshold = threshold
|
|
260
|
+
|
|
261
|
+
return best_threshold
```

## Hyperparameter Optimization

### Optuna Integration

```python
import optuna
|
|
270
|
+
from optuna.integration import MLflowCallback
|
|
271
|
+
|
|
272
|
+
def optimize_hyperparameters(
|
|
273
|
+
train_data: pd.DataFrame,
|
|
274
|
+
val_data: pd.DataFrame,
|
|
275
|
+
n_trials: int = 100,
|
|
276
|
+
timeout: int = 3600,
|
|
277
|
+
) -> dict:
|
|
278
|
+
"""Optimize hyperparameters with Optuna."""
|
|
279
|
+
|
|
280
|
+
def objective(trial: optuna.Trial) -> float:
|
|
281
|
+
# Define search space
|
|
282
|
+
params = {
|
|
283
|
+
"n_estimators": trial.suggest_int("n_estimators", 100, 1000, step=100),
|
|
284
|
+
"max_depth": trial.suggest_int("max_depth", 3, 12),
|
|
285
|
+
"learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
|
|
286
|
+
"subsample": trial.suggest_float("subsample", 0.6, 1.0),
|
|
287
|
+
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.6, 1.0),
|
|
288
|
+
"min_child_weight": trial.suggest_int("min_child_weight", 1, 10),
|
|
289
|
+
"reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 1.0, log=True),
|
|
290
|
+
"reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 1.0, log=True),
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
# Train with early stopping
|
|
294
|
+
model = XGBClassifier(**params, early_stopping_rounds=50, n_jobs=-1)
|
|
295
|
+
model.fit(
|
|
296
|
+
train_data.drop("label", axis=1),
|
|
297
|
+
train_data["label"],
|
|
298
|
+
eval_set=[(val_data.drop("label", axis=1), val_data["label"])],
|
|
299
|
+
verbose=False,
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
# Evaluate
|
|
303
|
+
y_prob = model.predict_proba(val_data.drop("label", axis=1))[:, 1]
|
|
304
|
+
return roc_auc_score(val_data["label"], y_prob)
|
|
305
|
+
|
|
306
|
+
# Create study with pruning
|
|
307
|
+
study = optuna.create_study(
|
|
308
|
+
direction="maximize",
|
|
309
|
+
pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
|
|
310
|
+
sampler=optuna.samplers.TPESampler(seed=42),
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
# Optimize with MLflow logging
|
|
314
|
+
study.optimize(
|
|
315
|
+
objective,
|
|
316
|
+
n_trials=n_trials,
|
|
317
|
+
timeout=timeout,
|
|
318
|
+
callbacks=[MLflowCallback(metric_name="val_roc_auc")],
|
|
319
|
+
show_progress_bar=True,
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
return study.best_params
```

### Neural Network Hyperparameters

```python
def objective_nn(trial: optuna.Trial) -> float:
|
|
329
|
+
"""Optimize neural network architecture and training."""
|
|
330
|
+
|
|
331
|
+
# Architecture search
|
|
332
|
+
n_layers = trial.suggest_int("n_layers", 2, 5)
|
|
333
|
+
layers = []
|
|
334
|
+
in_features = INPUT_DIM
|
|
335
|
+
|
|
336
|
+
for i in range(n_layers):
|
|
337
|
+
out_features = trial.suggest_int(f"n_units_l{i}", 32, 512, log=True)
|
|
338
|
+
layers.append(nn.Linear(in_features, out_features))
|
|
339
|
+
layers.append(nn.ReLU())
|
|
340
|
+
|
|
341
|
+
dropout = trial.suggest_float(f"dropout_l{i}", 0.1, 0.5)
|
|
342
|
+
layers.append(nn.Dropout(dropout))
|
|
343
|
+
|
|
344
|
+
in_features = out_features
|
|
345
|
+
|
|
346
|
+
layers.append(nn.Linear(in_features, NUM_CLASSES))
|
|
347
|
+
model = nn.Sequential(*layers)
|
|
348
|
+
|
|
349
|
+
# Optimizer selection
|
|
350
|
+
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW", "SGD"])
|
|
351
|
+
lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
|
|
352
|
+
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
|
|
353
|
+
|
|
354
|
+
if optimizer_name == "Adam":
|
|
355
|
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
|
356
|
+
elif optimizer_name == "AdamW":
|
|
357
|
+
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
|
358
|
+
else:
|
|
359
|
+
momentum = trial.suggest_float("momentum", 0.8, 0.99)
|
|
360
|
+
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
|
|
361
|
+
|
|
362
|
+
# Training
|
|
363
|
+
batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
|
|
364
|
+
|
|
365
|
+
for epoch in range(MAX_EPOCHS):
|
|
366
|
+
train_loss = train_epoch(model, optimizer, train_loader, batch_size)
|
|
367
|
+
val_acc = evaluate(model, val_loader)
|
|
368
|
+
|
|
369
|
+
# Pruning
|
|
370
|
+
trial.report(val_acc, epoch)
|
|
371
|
+
if trial.should_prune():
|
|
372
|
+
raise optuna.TrialPruned()
|
|
373
|
+
|
|
374
|
+
return val_acc
```

## Model Selection

### Cross-Validation

```python
from sklearn.model_selection import StratifiedKFold, cross_val_score
|
|
383
|
+
|
|
384
|
+
def evaluate_with_cv(
|
|
385
|
+
model,
|
|
386
|
+
X: pd.DataFrame,
|
|
387
|
+
y: pd.Series,
|
|
388
|
+
n_splits: int = 5,
|
|
389
|
+
scoring: str = "roc_auc"
|
|
390
|
+
) -> dict:
|
|
391
|
+
"""Evaluate model with stratified cross-validation."""
|
|
392
|
+
|
|
393
|
+
cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
|
|
394
|
+
|
|
395
|
+
scores = cross_val_score(model, X, y, cv=cv, scoring=scoring, n_jobs=-1)
|
|
396
|
+
|
|
397
|
+
return {
|
|
398
|
+
"mean": scores.mean(),
|
|
399
|
+
"std": scores.std(),
|
|
400
|
+
"min": scores.min(),
|
|
401
|
+
"max": scores.max(),
|
|
402
|
+
"scores": scores.tolist(),
|
|
403
|
+
}
```

### Model Comparison

```python
def compare_models(
|
|
410
|
+
models: dict[str, Any],
|
|
411
|
+
X_train: pd.DataFrame,
|
|
412
|
+
y_train: pd.Series,
|
|
413
|
+
X_test: pd.DataFrame,
|
|
414
|
+
y_test: pd.Series,
|
|
415
|
+
) -> pd.DataFrame:
|
|
416
|
+
"""Compare multiple models on the same data."""
|
|
417
|
+
|
|
418
|
+
results = []
|
|
419
|
+
|
|
420
|
+
for name, model in models.items():
|
|
421
|
+
start_time = time.time()
|
|
422
|
+
|
|
423
|
+
# Train
|
|
424
|
+
model.fit(X_train, y_train)
|
|
425
|
+
train_time = time.time() - start_time
|
|
426
|
+
|
|
427
|
+
# Predict
|
|
428
|
+
start_time = time.time()
|
|
429
|
+
y_pred = model.predict(X_test)
|
|
430
|
+
y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None
|
|
431
|
+
inference_time = time.time() - start_time
|
|
432
|
+
|
|
433
|
+
# Evaluate
|
|
434
|
+
metrics = ClassificationMetrics.compute(y_test, y_pred, y_prob)
|
|
435
|
+
|
|
436
|
+
results.append({
|
|
437
|
+
"model": name,
|
|
438
|
+
"accuracy": metrics.accuracy,
|
|
439
|
+
"precision": metrics.precision,
|
|
440
|
+
"recall": metrics.recall,
|
|
441
|
+
"f1": metrics.f1,
|
|
442
|
+
"roc_auc": metrics.roc_auc if y_prob is not None else None,
|
|
443
|
+
"train_time_s": train_time,
|
|
444
|
+
"inference_time_s": inference_time,
|
|
445
|
+
})
|
|
446
|
+
|
|
447
|
+
return pd.DataFrame(results).sort_values("f1", ascending=False)
```

## Model Registry

### MLflow Model Registry

```python
from mlflow.tracking import MlflowClient
|
|
456
|
+
|
|
457
|
+
def register_model(run_id: str, model_name: str, stage: str = "Staging") -> str:
|
|
458
|
+
"""Register model in MLflow Model Registry."""
|
|
459
|
+
|
|
460
|
+
client = MlflowClient()
|
|
461
|
+
|
|
462
|
+
# Register model version
|
|
463
|
+
model_uri = f"runs:/{run_id}/model"
|
|
464
|
+
result = mlflow.register_model(model_uri, model_name)
|
|
465
|
+
|
|
466
|
+
# Add metadata
|
|
467
|
+
client.update_model_version(
|
|
468
|
+
name=model_name,
|
|
469
|
+
version=result.version,
|
|
470
|
+
description=f"Model trained on {datetime.now().isoformat()}",
|
|
471
|
+
)
|
|
472
|
+
|
|
473
|
+
# Set tags
|
|
474
|
+
client.set_model_version_tag(
|
|
475
|
+
name=model_name,
|
|
476
|
+
version=result.version,
|
|
477
|
+
key="validation_status",
|
|
478
|
+
value="pending",
|
|
479
|
+
)
|
|
480
|
+
|
|
481
|
+
# Transition to stage
|
|
482
|
+
client.transition_model_version_stage(
|
|
483
|
+
name=model_name,
|
|
484
|
+
version=result.version,
|
|
485
|
+
stage=stage,
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
return result.version
|
|
489
|
+
|
|
490
|
+
def promote_model(model_name: str, version: str) -> None:
|
|
491
|
+
"""Promote model from Staging to Production."""
|
|
492
|
+
|
|
493
|
+
client = MlflowClient()
|
|
494
|
+
|
|
495
|
+
# Archive current production model
|
|
496
|
+
prod_versions = client.get_latest_versions(model_name, stages=["Production"])
|
|
497
|
+
for v in prod_versions:
|
|
498
|
+
client.transition_model_version_stage(
|
|
499
|
+
name=model_name,
|
|
500
|
+
version=v.version,
|
|
501
|
+
stage="Archived",
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
# Promote new version
|
|
505
|
+
client.transition_model_version_stage(
|
|
506
|
+
name=model_name,
|
|
507
|
+
version=version,
|
|
508
|
+
stage="Production",
|
|
509
|
+
)
```

## Best Practices

### Reproducibility Checklist

- [ ] Set random seeds for all sources of randomness
- [ ] Pin all dependency versions
- [ ] Log all hyperparameters and configurations
- [ ] Version training data
- [ ] Use deterministic algorithms where possible
- [ ] Log environment information (GPU, CUDA version)

### Evaluation Checklist

- [ ] Evaluate on held-out test set (not used for validation)
- [ ] Report multiple metrics, not just accuracy
- [ ] Include confidence intervals or standard deviations
- [ ] Evaluate across relevant segments/slices
- [ ] Compare against meaningful baselines
- [ ] Check for overfitting (train vs val vs test gaps)

### Model Selection Checklist

- [ ] Consider inference latency requirements
- [ ] Evaluate model size and memory footprint
- [ ] Test with realistic input distributions
- [ ] Validate on out-of-distribution samples
- [ ] Document model limitations