omgkit 2.19.3 → 2.21.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +537 -338
- package/package.json +2 -2
- package/plugin/agents/ai-architect-agent.md +282 -0
- package/plugin/agents/data-scientist-agent.md +221 -0
- package/plugin/agents/experiment-analyst-agent.md +318 -0
- package/plugin/agents/ml-engineer-agent.md +165 -0
- package/plugin/agents/mlops-engineer-agent.md +324 -0
- package/plugin/agents/model-optimizer-agent.md +287 -0
- package/plugin/agents/production-engineer-agent.md +360 -0
- package/plugin/agents/research-scientist-agent.md +274 -0
- package/plugin/commands/omgdata/augment.md +86 -0
- package/plugin/commands/omgdata/collect.md +81 -0
- package/plugin/commands/omgdata/label.md +83 -0
- package/plugin/commands/omgdata/split.md +83 -0
- package/plugin/commands/omgdata/validate.md +76 -0
- package/plugin/commands/omgdata/version.md +85 -0
- package/plugin/commands/omgdeploy/ab.md +94 -0
- package/plugin/commands/omgdeploy/cloud.md +89 -0
- package/plugin/commands/omgdeploy/edge.md +93 -0
- package/plugin/commands/omgdeploy/package.md +91 -0
- package/plugin/commands/omgdeploy/serve.md +92 -0
- package/plugin/commands/omgfeature/embed.md +93 -0
- package/plugin/commands/omgfeature/extract.md +93 -0
- package/plugin/commands/omgfeature/select.md +85 -0
- package/plugin/commands/omgfeature/store.md +97 -0
- package/plugin/commands/omgml/init.md +60 -0
- package/plugin/commands/omgml/status.md +82 -0
- package/plugin/commands/omgops/drift.md +87 -0
- package/plugin/commands/omgops/monitor.md +99 -0
- package/plugin/commands/omgops/pipeline.md +102 -0
- package/plugin/commands/omgops/registry.md +109 -0
- package/plugin/commands/omgops/retrain.md +91 -0
- package/plugin/commands/omgoptim/distill.md +90 -0
- package/plugin/commands/omgoptim/profile.md +92 -0
- package/plugin/commands/omgoptim/prune.md +81 -0
- package/plugin/commands/omgoptim/quantize.md +83 -0
- package/plugin/commands/omgtrain/baseline.md +78 -0
- package/plugin/commands/omgtrain/compare.md +99 -0
- package/plugin/commands/omgtrain/evaluate.md +85 -0
- package/plugin/commands/omgtrain/train.md +81 -0
- package/plugin/commands/omgtrain/tune.md +89 -0
- package/plugin/registry.yaml +252 -2
- package/plugin/skills/ml-systems/SKILL.md +65 -0
- package/plugin/skills/ml-systems/ai-accelerators/SKILL.md +342 -0
- package/plugin/skills/ml-systems/data-eng/SKILL.md +126 -0
- package/plugin/skills/ml-systems/deep-learning-primer/SKILL.md +143 -0
- package/plugin/skills/ml-systems/deployment-paradigms/SKILL.md +148 -0
- package/plugin/skills/ml-systems/dnn-architectures/SKILL.md +128 -0
- package/plugin/skills/ml-systems/edge-deployment/SKILL.md +366 -0
- package/plugin/skills/ml-systems/efficient-ai/SKILL.md +316 -0
- package/plugin/skills/ml-systems/feature-engineering/SKILL.md +151 -0
- package/plugin/skills/ml-systems/ml-frameworks/SKILL.md +187 -0
- package/plugin/skills/ml-systems/ml-serving-optimization/SKILL.md +371 -0
- package/plugin/skills/ml-systems/ml-systems-fundamentals/SKILL.md +103 -0
- package/plugin/skills/ml-systems/ml-workflow/SKILL.md +162 -0
- package/plugin/skills/ml-systems/mlops/SKILL.md +386 -0
- package/plugin/skills/ml-systems/model-deployment/SKILL.md +350 -0
- package/plugin/skills/ml-systems/model-dev/SKILL.md +160 -0
- package/plugin/skills/ml-systems/model-optimization/SKILL.md +339 -0
- package/plugin/skills/ml-systems/robust-ai/SKILL.md +395 -0
- package/plugin/skills/ml-systems/training-data/SKILL.md +152 -0
- package/plugin/workflows/ml-systems/data-preparation-workflow.md +276 -0
- package/plugin/workflows/ml-systems/edge-deployment-workflow.md +413 -0
- package/plugin/workflows/ml-systems/full-ml-lifecycle-workflow.md +405 -0
- package/plugin/workflows/ml-systems/hyperparameter-tuning-workflow.md +352 -0
- package/plugin/workflows/ml-systems/mlops-pipeline-workflow.md +384 -0
- package/plugin/workflows/ml-systems/model-deployment-workflow.md +392 -0
- package/plugin/workflows/ml-systems/model-development-workflow.md +218 -0
- package/plugin/workflows/ml-systems/model-evaluation-workflow.md +416 -0
- package/plugin/workflows/ml-systems/model-optimization-workflow.md +390 -0
- package/plugin/workflows/ml-systems/monitoring-drift-workflow.md +446 -0
- package/plugin/workflows/ml-systems/retraining-workflow.md +401 -0
- package/plugin/workflows/ml-systems/training-pipeline-workflow.md +382 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: robust-ai
|
|
3
|
+
description: Building robust AI systems including model monitoring, drift detection, reliability engineering, and failure handling for production ML.
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Robust AI
|
|
7
|
+
|
|
8
|
+
Building reliable and robust ML systems.
|
|
9
|
+
|
|
10
|
+
## Robustness Framework
|
|
11
|
+
|
|
12
|
+
```
|
|
13
|
+
┌─────────────────────────────────────────────────────────────┐
|
|
14
|
+
│ AI ROBUSTNESS LAYERS │
|
|
15
|
+
├─────────────────────────────────────────────────────────────┤
|
|
16
|
+
│ │
|
|
17
|
+
│ DATA QUALITY MODEL QUALITY SYSTEM QUALITY │
|
|
18
|
+
│ ──────────── ───────────── ────────────── │
|
|
19
|
+
│ Validation Testing Monitoring │
|
|
20
|
+
│ Anomaly detection Adversarial test Alerting │
|
|
21
|
+
│ Drift detection Uncertainty Fallbacks │
|
|
22
|
+
│ │
|
|
23
|
+
│ FAILURE MODES: │
|
|
24
|
+
│ ├── Data drift: Input distribution changes │
|
|
25
|
+
│ ├── Concept drift: Input-output relationship changes │
|
|
26
|
+
│ ├── Model degradation: Performance decline over time │
|
|
27
|
+
│ ├── Silent failures: Wrong predictions with high confidence│
|
|
28
|
+
│ └── System failures: Infrastructure and latency issues │
|
|
29
|
+
│ │
|
|
30
|
+
└─────────────────────────────────────────────────────────────┘
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Model Monitoring
|
|
34
|
+
|
|
35
|
+
### Prometheus + Grafana Setup
|
|
36
|
+
```python
|
|
37
|
+
from prometheus_client import Counter, Histogram, Gauge, start_http_server
|
|
38
|
+
|
|
39
|
+
# Metrics exposed to Prometheus. Label names declared here must match the
# .labels() calls exactly.
PREDICTIONS = Counter('model_predictions_total', 'Total predictions', ['model', 'class'])
LATENCY = Histogram('model_latency_seconds', 'Prediction latency', ['model'])
CONFIDENCE = Histogram('model_confidence', 'Prediction confidence', ['model'], buckets=[0.5, 0.7, 0.9, 0.95, 0.99])
DRIFT_SCORE = Gauge('model_drift_score', 'Data drift score', ['model', 'feature'])

class MonitoredModel:
    """Wrap a model so every prediction records latency, class and confidence."""

    def __init__(self, model, model_name):
        self.model = model
        self.model_name = model_name

    def predict(self, x):
        """Run the wrapped model on *x*, record metrics, and return its raw output."""
        # Histogram.time() records the wall-clock duration of the forward pass.
        with LATENCY.labels(model=self.model_name).time():
            output = self.model(x)

        probs = torch.softmax(output, dim=1)
        # NOTE(review): .item() assumes a single-sample batch — confirm callers
        # never pass batch sizes > 1.
        pred_class = probs.argmax(dim=1).item()
        confidence = probs.max().item()

        # FIX: the label is named 'class', so the keyword form `class_=` raised
        # "incorrect label names" at runtime; pass label values positionally.
        PREDICTIONS.labels(self.model_name, str(pred_class)).inc()
        CONFIDENCE.labels(model=self.model_name).observe(confidence)

        return output

# Expose /metrics on port 8000 for Prometheus scraping.
start_http_server(8000)
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
### Evidently AI Monitoring
|
|
68
|
+
```python
|
|
69
|
+
from evidently import ColumnMapping
|
|
70
|
+
from evidently.report import Report
|
|
71
|
+
from evidently.metric_preset import DataDriftPreset, DataQualityPreset
|
|
72
|
+
from evidently.metrics import ColumnDriftMetric, DatasetDriftMetric
|
|
73
|
+
|
|
74
|
+
# Define column mapping: tell Evidently which columns are the ground truth,
# the model output, and which features are numeric vs categorical.
column_mapping = ColumnMapping(
    target='target',
    prediction='prediction',
    numerical_features=['age', 'income', 'score'],
    categorical_features=['category', 'region']
)

# Create drift report. The presets bundle many standard checks; the two
# explicit metrics add a per-column drift test on 'age' and a dataset-level
# drift share.
report = Report(metrics=[
    DataDriftPreset(),
    DataQualityPreset(),
    ColumnDriftMetric(column_name='age'),
    DatasetDriftMetric()
])

# reference_data is the training-time baseline; current_data is recent
# production traffic to compare against it.
report.run(
    reference_data=reference_df,
    current_data=current_df,
    column_mapping=column_mapping
)

# Save report as interactive HTML for human review.
report.save_html('drift_report.html')

# Get drift results programmatically.
results = report.as_dict()
# NOTE(review): metrics[0] assumes DataDriftPreset stays first in the metrics
# list above — confirm if the list is reordered.
drift_detected = results['metrics'][0]['result']['dataset_drift']
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
## Drift Detection
|
|
105
|
+
|
|
106
|
+
### Statistical Drift Detection
|
|
107
|
+
```python
|
|
108
|
+
from scipy import stats
|
|
109
|
+
import numpy as np
|
|
110
|
+
|
|
111
|
+
class DriftDetector:
    """Per-column two-sample drift tests against a fixed reference frame.

    Numerical columns use the two-sample Kolmogorov-Smirnov test; categorical
    columns use a chi-square goodness-of-fit test against the reference
    distribution.
    """

    def __init__(self, reference_data, significance_level=0.05):
        self.reference = reference_data
        self.significance = significance_level

    def detect_drift(self, current_data):
        """Return {column: {'statistic', 'p_value', 'drift_detected'}} per reference column."""
        results = {}

        for col in self.reference.columns:
            if self.reference[col].dtype in ['float64', 'int64']:
                # Kolmogorov-Smirnov compares the two empirical distributions directly.
                stat, p_value = stats.ks_2samp(
                    self.reference[col],
                    current_data[col]
                )
            else:
                # FIX: value_counts() orders by frequency and may omit categories,
                # so the two count vectors were neither aligned by category nor
                # equal in total (scipy.chisquare then tests the wrong pairs or
                # raises). Align on the union of categories, smooth to avoid
                # zero expected counts, and rescale to the observed total.
                categories = sorted(set(self.reference[col]) | set(current_data[col]))
                ref_counts = self.reference[col].value_counts().reindex(categories, fill_value=0)
                cur_counts = current_data[col].value_counts().reindex(categories, fill_value=0)
                smoothed = ref_counts + 0.5  # Laplace smoothing for unseen categories
                expected = smoothed / smoothed.sum() * cur_counts.sum()
                stat, p_value = stats.chisquare(cur_counts, expected)

            results[col] = {
                'statistic': stat,
                'p_value': p_value,
                'drift_detected': p_value < self.significance
            }

        return results
|
|
139
|
+
|
|
140
|
+
# Population Stability Index (PSI): a divergence between the binned
# distributions of a reference sample and a current sample.
def calculate_psi(reference, current, bins=10):
    """Return the PSI between *reference* and *current* (shared binning)."""
    # Bin edges come from the reference so both samples are binned identically.
    ref_hist, edges = np.histogram(reference, bins=bins)
    cur_hist, _ = np.histogram(current, bins=edges)

    # Convert to fractions; floor at 1e-4 so empty bins cannot divide by zero.
    ref_frac = np.clip(ref_hist / len(reference), 0.0001, None)
    cur_frac = np.clip(cur_hist / len(current), 0.0001, None)

    # Sum over bins of (cur - ref) * ln(cur / ref).
    # PSI > 0.25 indicates significant drift.
    return np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac))
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
### Concept Drift Detection
|
|
158
|
+
```python
|
|
159
|
+
from river import drift
|
|
160
|
+
|
|
161
|
+
class ConceptDriftMonitor:
    """Track prediction errors with two river drift detectors.

    ADWIN reacts to gradual drift via adaptive windowing; DDM reacts to
    sudden drift via error-rate control limits.
    """

    def __init__(self):
        # NOTE(review): on recent river versions DDM lives under drift.binary —
        # confirm the installed version exposes drift.DDM.
        self.adwin = drift.ADWIN()
        self.ddm = drift.DDM()
        self.performance_window = []

    def update(self, y_true, y_pred):
        """Feed one labelled prediction; return per-detector drift flags."""
        misclassified = int(y_true != y_pred)

        # ADWIN for gradual drift, DDM for sudden drift.
        self.adwin.update(misclassified)
        self.ddm.update(misclassified)

        return {
            'adwin_drift': self.adwin.drift_detected,
            'ddm_drift': self.ddm.drift_detected,
            'error_rate': self.adwin.estimation
        }
|
|
183
|
+
|
|
184
|
+
# Performance-based drift detection: compare a rolling accuracy window
# against the first full window ever seen (the baseline).
class PerformanceDriftDetector:
    def __init__(self, window_size=1000, threshold=0.1):
        self.window_size = window_size
        self.threshold = threshold
        self.baseline_accuracy = None
        # FIX: the original appended every outcome to an ever-growing list and
        # sliced the tail on each update; a bounded deque keeps memory constant.
        from collections import deque
        self.current_window = deque(maxlen=window_size)

    def update(self, y_true, y_pred):
        """Record one outcome; return drift info once a full window exists, else None."""
        self.current_window.append(int(y_true == y_pred))

        if len(self.current_window) >= self.window_size:
            current_accuracy = sum(self.current_window) / len(self.current_window)

            # The first full window defines the baseline and is never updated.
            if self.baseline_accuracy is None:
                self.baseline_accuracy = current_accuracy

            return {
                'baseline': self.baseline_accuracy,
                'current': current_accuracy,
                'drift_detected': (self.baseline_accuracy - current_accuracy) > self.threshold
            }

        return None
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
## Uncertainty Estimation
|
|
213
|
+
|
|
214
|
+
```python
|
|
215
|
+
import torch
|
|
216
|
+
import torch.nn as nn
|
|
217
|
+
import torch.nn.functional as F
|
|
218
|
+
|
|
219
|
+
class MCDropoutModel(nn.Module):
    """Monte Carlo Dropout for uncertainty estimation.

    Runs the base model several times with dropout active and reports the
    mean prediction plus the epistemic uncertainty (variance across samples,
    averaged over the output dimension).
    """

    def __init__(self, base_model, dropout_rate=0.1):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, num_samples=30):
        """Return (mean prediction, epistemic uncertainty) over *num_samples* draws."""
        # nn.Dropout is only stochastic in train mode. FIX: the original called
        # self.train() and never restored the previous mode, silently leaving
        # the whole module (including e.g. batch-norm layers) in training mode
        # for every subsequent caller.
        was_training = self.training
        self.train()
        try:
            samples = [self.dropout(self.base_model(x)) for _ in range(num_samples)]
        finally:
            self.train(was_training)

        stacked = torch.stack(samples)
        mean = stacked.mean(dim=0)
        # Variance across MC samples, averaged over the class dimension.
        epistemic_uncertainty = stacked.var(dim=0).mean(dim=-1)

        return mean, epistemic_uncertainty
|
|
242
|
+
|
|
243
|
+
# Deep Ensembles: aggregate independently trained models; disagreement
# between members serves as the uncertainty estimate.
class EnsembleModel:
    def __init__(self, models):
        self.models = models

    def predict_with_uncertainty(self, x):
        """Return (mean prediction, per-output variance) across ensemble members."""
        member_outputs = []
        for member in self.models:
            member.eval()
            with torch.no_grad():
                member_outputs.append(member(x))

        stacked = torch.stack(member_outputs)
        return stacked.mean(dim=0), stacked.var(dim=0)
|
|
261
|
+
|
|
262
|
+
# Calibration (Temperature Scaling): learn a single scalar temperature that
# rescales logits so predicted confidences match observed accuracy.
class TemperatureScaling(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return the wrapped model's logits divided by the learned temperature."""
        logits = self.model(x)
        return logits / self.temperature

    def calibrate(self, val_loader):
        """Fit the temperature on a validation loader by minimising NLL.

        FIX: the original re-ran the full model inside every LBFGS closure
        evaluation and backpropagated through it, although only the
        temperature is trainable. Precompute the logits once under no_grad
        and optimise the temperature against them.
        """
        nll_criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        with torch.no_grad():
            all_logits = torch.cat([self.model(x) for x, _ in val_loader])
            all_labels = torch.cat([y for _, y in val_loader])

        def eval_loss():
            optimizer.zero_grad()
            loss = nll_criterion(all_logits / self.temperature, all_labels)
            loss.backward()
            return loss

        optimizer.step(eval_loss)
|
|
288
|
+
```
|
|
289
|
+
|
|
290
|
+
## Fallback Strategies
|
|
291
|
+
|
|
292
|
+
```python
|
|
293
|
+
class RobustInferenceService:
    """Tiered inference: primary model, then a fallback model at a relaxed
    threshold, then hand-written rules, then a safe default on any error."""

    def __init__(self, primary_model, fallback_model, confidence_threshold=0.7):
        self.primary = primary_model
        self.fallback = fallback_model
        self.threshold = confidence_threshold
        self.rule_based_fallback = RuleBasedModel()

    def predict(self, x):
        """Return a dict with 'prediction', 'confidence' and the 'model' tier used."""
        try:
            # Try each learned model in order; the fallback is accepted at a
            # relaxed (80%) confidence bar.
            for model, tag, min_conf in (
                (self.primary, 'primary', self.threshold),
                (self.fallback, 'fallback', self.threshold * 0.8),
            ):
                output = model(x)
                confidence = torch.softmax(output, dim=1).max().item()
                if confidence >= min_conf:
                    return {
                        'prediction': output.argmax().item(),
                        'confidence': confidence,
                        'model': tag
                    }

            # Both learned models were unsure — fall back to deterministic rules.
            return {
                'prediction': self.rule_based_fallback(x),
                'confidence': None,
                'model': 'rule_based'
            }

        except Exception as e:
            # Infrastructure failure — degrade to a cached/safe default.
            return {
                'prediction': self.get_default_prediction(),
                'confidence': None,
                'model': 'default',
                'error': str(e)
            }

    def get_default_prediction(self):
        # Most common class or other known-safe default.
        return 0
|
|
343
|
+
```
|
|
344
|
+
|
|
345
|
+
## Automated Retraining
|
|
346
|
+
|
|
347
|
+
```python
|
|
348
|
+
class AutoRetrainTrigger:
    """Decide when a deployed model should be retrained.

    Triggers on (a) a drift score above the threshold, (b) accuracy below a
    floor, or (c) seven consecutive strictly-declining accuracy readings.
    """

    def __init__(self, drift_threshold=0.2, accuracy_threshold=0.85):
        self.drift_threshold = drift_threshold
        self.accuracy_threshold = accuracy_threshold
        self.metrics_history = []

    def should_retrain(self, metrics):
        """Append *metrics* to the history and return (should_retrain, reason)."""
        self.metrics_history.append(metrics)

        # Check data drift.
        if metrics.get('drift_score', 0) > self.drift_threshold:
            return True, 'data_drift'

        # Check accuracy degradation.
        if metrics.get('accuracy', 1.0) < self.accuracy_threshold:
            return True, 'accuracy_drop'

        # Check declining trend over the last seven readings.
        if len(self.metrics_history) >= 7:
            # FIX: entries without an 'accuracy' key used to raise KeyError
            # here; use the same 1.0 default as the absolute check above.
            recent = [m.get('accuracy', 1.0) for m in self.metrics_history[-7:]]
            if all(later < earlier for earlier, later in zip(recent, recent[1:])):
                return True, 'declining_trend'

        return False, None

    def trigger_retrain(self, reason):
        """Kick off the Airflow retraining DAG, recording why it was triggered."""
        from airflow.api.client.local_client import Client
        client = Client(None, None)
        client.trigger_dag(
            dag_id='model_retraining',
            conf={'trigger_reason': reason}
        )
|
|
381
|
+
```
|
|
382
|
+
|
|
383
|
+
## Commands
|
|
384
|
+
- `/omgops:monitor` - Setup monitoring
|
|
385
|
+
- `/omgops:drift` - Drift detection
|
|
386
|
+
- `/omgops:retrain` - Trigger retraining
|
|
387
|
+
- `/omgtrain:evaluate` - Evaluate model
|
|
388
|
+
|
|
389
|
+
## Best Practices
|
|
390
|
+
|
|
391
|
+
1. Monitor predictions, not just system metrics
|
|
392
|
+
2. Set up automated drift detection
|
|
393
|
+
3. Implement graceful degradation
|
|
394
|
+
4. Use uncertainty estimation
|
|
395
|
+
5. Have clear retraining triggers
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: training-data
|
|
3
|
+
description: Training data management including labeling strategies, data augmentation, handling imbalanced data, and data splitting best practices.
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Training Data
|
|
7
|
+
|
|
8
|
+
Managing and improving training data quality.
|
|
9
|
+
|
|
10
|
+
## Data Labeling Strategies
|
|
11
|
+
|
|
12
|
+
### Manual Labeling
|
|
13
|
+
```python
|
|
14
|
+
# Export for Label Studio: one task object per row of text.
def export_for_labeling(data: pd.DataFrame, output_path: str) -> None:
    """Write *data* as a Label Studio task file (a JSON list) to *output_path*.

    Assumes *data* has a 'text' column; the row index becomes the task id.
    """
    tasks = [
        # int(idx) guards against numpy integer indices, which json.dump
        # cannot serialize.
        {"data": {"text": row["text"]}, "id": int(idx)}
        for idx, row in data.iterrows()
    ]
    # Explicit encoding so the export is portable across platforms.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(tasks, f)
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
### Weak Supervision
|
|
25
|
+
```python
|
|
26
|
+
# FIX: PandasLFApplier is used below but was never imported.
from snorkel.labeling import labeling_function, LabelingFunction, PandasLFApplier
from snorkel.labeling.model import LabelModel

# Labeling functions vote 1 (positive) or -1 per example.
# NOTE(review): in Snorkel's convention -1 means ABSTAIN, not "negative class" —
# confirm the intended label scheme.
@labeling_function()
def lf_keyword(x):
    keywords = ["urgent", "free", "winner"]
    return 1 if any(k in x.text.lower() for k in keywords) else -1

@labeling_function()
def lf_short_text(x):
    return 1 if len(x.text) < 20 else -1

# Combine weak labels: apply every LF over the frame, then fit a generative
# label model that denoises and weights their votes.
lfs = [lf_keyword, lf_short_text]
applier = PandasLFApplier(lfs)
L_train = applier.apply(df_train)

label_model = LabelModel(cardinality=2)
label_model.fit(L_train, n_epochs=100)
labels = label_model.predict(L_train)
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
### Active Learning
|
|
49
|
+
```python
|
|
50
|
+
from modAL.models import ActiveLearner
# FIX: uncertainty_sampling is used below but was never imported.
from modAL.uncertainty import uncertainty_sampling
from sklearn.ensemble import RandomForestClassifier

# Seed the learner with a small labelled set; query the pool where the
# model is least certain.
learner = ActiveLearner(
    estimator=RandomForestClassifier(),
    query_strategy=uncertainty_sampling,
    X_training=X_initial, y_training=y_initial
)

for _ in range(n_queries):
    query_idx, query_instance = learner.query(X_pool)
    # A human oracle labels the selected instance.
    y_new = get_human_label(query_instance)
    learner.teach(X_pool[query_idx], y_new)
    # Remove the now-labelled row so it cannot be queried again.
    X_pool = np.delete(X_pool, query_idx, axis=0)
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## Data Augmentation
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
# Text augmentation: replace words with WordNet synonyms to paraphrase examples.
import nlpaug.augmenter.word as naw

aug = naw.SynonymAug(aug_src='wordnet')
# NOTE(review): recent nlpaug versions return a list of strings here — confirm.
augmented = aug.augment("The quick brown fox jumps over the lazy dog")

# Image augmentation: random crop, horizontal flip and brightness/contrast
# jitter, then normalisation, composed into one pipeline.
import albumentations as A

transform = A.Compose([
    A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize()
])

# Tabular augmentation (SMOTE): synthesise minority-class rows by
# interpolating between nearest neighbours; random_state pins the result.
from imblearn.over_sampling import SMOTE

smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
## Handling Imbalanced Data
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
# Class weights: 'balanced' assigns each class an inverse-frequency weight,
# to pass to the model's loss so minority classes are not drowned out.
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced', classes=np.unique(y), y=y)
|
|
100
|
+
|
|
101
|
+
# Focal loss (Lin et al., 2017): down-weights easy examples so training
# focuses on hard, misclassified ones.
def focal_loss(y_true, y_pred, gamma=2, alpha=0.25):
    """Binary focal loss for probabilities *y_pred* against targets *y_true*.

    gamma controls how strongly easy examples are down-weighted; alpha
    balances the positive class (1 - alpha weights the negative class).
    """
    bce = F.binary_cross_entropy(y_pred, y_true, reduction='none')
    # p_t: the predicted probability of the true class.
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    # FIX: alpha must be class-conditional (alpha for positives, 1 - alpha
    # for negatives); the original applied a constant alpha to both classes.
    alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
    focal_weight = alpha_t * (1 - p_t) ** gamma
    return (focal_weight * bce).mean()
|
|
107
|
+
|
|
108
|
+
# Stratified sampling
|
|
109
|
+
from sklearn.model_selection import StratifiedKFold
|
|
110
|
+
|
|
111
|
+
skf = StratifiedKFold(n_splits=5, shuffle=True)
|
|
112
|
+
for train_idx, val_idx in skf.split(X, y):
|
|
113
|
+
X_train, X_val = X[train_idx], X[val_idx]
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
## Data Splitting
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
from sklearn.model_selection import train_test_split
|
|
120
|
+
|
|
121
|
+
# Random split (with stratification): stratify=y keeps the class ratio
# identical in train and test; random_state pins the shuffle for
# reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
|
|
125
|
+
|
|
126
|
+
# Temporal split (for time series): never let the model peek at the future.
def temporal_split(df, time_col, train_end, val_end):
    """Split *df* chronologically: train < train_end <= val < val_end <= test."""
    times = df[time_col]
    train = df[times < train_end]
    val = df[(times >= train_end) & (times < val_end)]
    test = df[times >= val_end]
    return train, val, test
|
|
132
|
+
|
|
133
|
+
# Group split (no data leakage): all rows sharing a group id (here a user)
# land on the same side of the split, so the model is never evaluated on
# users it trained on.
from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(n_splits=1, test_size=0.2)
for train_idx, test_idx in gss.split(X, y, groups=user_ids):
    X_train, X_test = X[train_idx], X[test_idx]
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
## Commands
|
|
142
|
+
- `/omgdata:label` - Data labeling
|
|
143
|
+
- `/omgdata:augment` - Augmentation
|
|
144
|
+
- `/omgdata:split` - Data splitting
|
|
145
|
+
|
|
146
|
+
## Best Practices
|
|
147
|
+
|
|
148
|
+
1. Start with a small, high-quality labeled set
|
|
149
|
+
2. Use weak supervision to scale labeling
|
|
150
|
+
3. Match augmentation to your domain
|
|
151
|
+
4. Prevent data leakage in splits
|
|
152
|
+
5. Monitor label quality over time
|