mcli-framework 7.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcli-framework might be problematic. Click here for more details.
- mcli/app/chat_cmd.py +42 -0
- mcli/app/commands_cmd.py +226 -0
- mcli/app/completion_cmd.py +216 -0
- mcli/app/completion_helpers.py +288 -0
- mcli/app/cron_test_cmd.py +697 -0
- mcli/app/logs_cmd.py +419 -0
- mcli/app/main.py +492 -0
- mcli/app/model/model.py +1060 -0
- mcli/app/model_cmd.py +227 -0
- mcli/app/redis_cmd.py +269 -0
- mcli/app/video/video.py +1114 -0
- mcli/app/visual_cmd.py +303 -0
- mcli/chat/chat.py +2409 -0
- mcli/chat/command_rag.py +514 -0
- mcli/chat/enhanced_chat.py +652 -0
- mcli/chat/system_controller.py +1010 -0
- mcli/chat/system_integration.py +1016 -0
- mcli/cli.py +25 -0
- mcli/config.toml +20 -0
- mcli/lib/api/api.py +586 -0
- mcli/lib/api/daemon_client.py +203 -0
- mcli/lib/api/daemon_client_local.py +44 -0
- mcli/lib/api/daemon_decorator.py +217 -0
- mcli/lib/api/mcli_decorators.py +1032 -0
- mcli/lib/auth/auth.py +85 -0
- mcli/lib/auth/aws_manager.py +85 -0
- mcli/lib/auth/azure_manager.py +91 -0
- mcli/lib/auth/credential_manager.py +192 -0
- mcli/lib/auth/gcp_manager.py +93 -0
- mcli/lib/auth/key_manager.py +117 -0
- mcli/lib/auth/mcli_manager.py +93 -0
- mcli/lib/auth/token_manager.py +75 -0
- mcli/lib/auth/token_util.py +1011 -0
- mcli/lib/config/config.py +47 -0
- mcli/lib/discovery/__init__.py +1 -0
- mcli/lib/discovery/command_discovery.py +274 -0
- mcli/lib/erd/erd.py +1345 -0
- mcli/lib/erd/generate_graph.py +453 -0
- mcli/lib/files/files.py +76 -0
- mcli/lib/fs/fs.py +109 -0
- mcli/lib/lib.py +29 -0
- mcli/lib/logger/logger.py +611 -0
- mcli/lib/performance/optimizer.py +409 -0
- mcli/lib/performance/rust_bridge.py +502 -0
- mcli/lib/performance/uvloop_config.py +154 -0
- mcli/lib/pickles/pickles.py +50 -0
- mcli/lib/search/cached_vectorizer.py +479 -0
- mcli/lib/services/data_pipeline.py +460 -0
- mcli/lib/services/lsh_client.py +441 -0
- mcli/lib/services/redis_service.py +387 -0
- mcli/lib/shell/shell.py +137 -0
- mcli/lib/toml/toml.py +33 -0
- mcli/lib/ui/styling.py +47 -0
- mcli/lib/ui/visual_effects.py +634 -0
- mcli/lib/watcher/watcher.py +185 -0
- mcli/ml/api/app.py +215 -0
- mcli/ml/api/middleware.py +224 -0
- mcli/ml/api/routers/admin_router.py +12 -0
- mcli/ml/api/routers/auth_router.py +244 -0
- mcli/ml/api/routers/backtest_router.py +12 -0
- mcli/ml/api/routers/data_router.py +12 -0
- mcli/ml/api/routers/model_router.py +302 -0
- mcli/ml/api/routers/monitoring_router.py +12 -0
- mcli/ml/api/routers/portfolio_router.py +12 -0
- mcli/ml/api/routers/prediction_router.py +267 -0
- mcli/ml/api/routers/trade_router.py +12 -0
- mcli/ml/api/routers/websocket_router.py +76 -0
- mcli/ml/api/schemas.py +64 -0
- mcli/ml/auth/auth_manager.py +425 -0
- mcli/ml/auth/models.py +154 -0
- mcli/ml/auth/permissions.py +302 -0
- mcli/ml/backtesting/backtest_engine.py +502 -0
- mcli/ml/backtesting/performance_metrics.py +393 -0
- mcli/ml/cache.py +400 -0
- mcli/ml/cli/main.py +398 -0
- mcli/ml/config/settings.py +394 -0
- mcli/ml/configs/dvc_config.py +230 -0
- mcli/ml/configs/mlflow_config.py +131 -0
- mcli/ml/configs/mlops_manager.py +293 -0
- mcli/ml/dashboard/app.py +532 -0
- mcli/ml/dashboard/app_integrated.py +738 -0
- mcli/ml/dashboard/app_supabase.py +560 -0
- mcli/ml/dashboard/app_training.py +615 -0
- mcli/ml/dashboard/cli.py +51 -0
- mcli/ml/data_ingestion/api_connectors.py +501 -0
- mcli/ml/data_ingestion/data_pipeline.py +567 -0
- mcli/ml/data_ingestion/stream_processor.py +512 -0
- mcli/ml/database/migrations/env.py +94 -0
- mcli/ml/database/models.py +667 -0
- mcli/ml/database/session.py +200 -0
- mcli/ml/experimentation/ab_testing.py +845 -0
- mcli/ml/features/ensemble_features.py +607 -0
- mcli/ml/features/political_features.py +676 -0
- mcli/ml/features/recommendation_engine.py +809 -0
- mcli/ml/features/stock_features.py +573 -0
- mcli/ml/features/test_feature_engineering.py +346 -0
- mcli/ml/logging.py +85 -0
- mcli/ml/mlops/data_versioning.py +518 -0
- mcli/ml/mlops/experiment_tracker.py +377 -0
- mcli/ml/mlops/model_serving.py +481 -0
- mcli/ml/mlops/pipeline_orchestrator.py +614 -0
- mcli/ml/models/base_models.py +324 -0
- mcli/ml/models/ensemble_models.py +675 -0
- mcli/ml/models/recommendation_models.py +474 -0
- mcli/ml/models/test_models.py +487 -0
- mcli/ml/monitoring/drift_detection.py +676 -0
- mcli/ml/monitoring/metrics.py +45 -0
- mcli/ml/optimization/portfolio_optimizer.py +834 -0
- mcli/ml/preprocessing/data_cleaners.py +451 -0
- mcli/ml/preprocessing/feature_extractors.py +491 -0
- mcli/ml/preprocessing/ml_pipeline.py +382 -0
- mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
- mcli/ml/preprocessing/test_preprocessing.py +294 -0
- mcli/ml/scripts/populate_sample_data.py +200 -0
- mcli/ml/tasks.py +400 -0
- mcli/ml/tests/test_integration.py +429 -0
- mcli/ml/tests/test_training_dashboard.py +387 -0
- mcli/public/oi/oi.py +15 -0
- mcli/public/public.py +4 -0
- mcli/self/self_cmd.py +1246 -0
- mcli/workflow/daemon/api_daemon.py +800 -0
- mcli/workflow/daemon/async_command_database.py +681 -0
- mcli/workflow/daemon/async_process_manager.py +591 -0
- mcli/workflow/daemon/client.py +530 -0
- mcli/workflow/daemon/commands.py +1196 -0
- mcli/workflow/daemon/daemon.py +905 -0
- mcli/workflow/daemon/daemon_api.py +59 -0
- mcli/workflow/daemon/enhanced_daemon.py +571 -0
- mcli/workflow/daemon/process_cli.py +244 -0
- mcli/workflow/daemon/process_manager.py +439 -0
- mcli/workflow/daemon/test_daemon.py +275 -0
- mcli/workflow/dashboard/dashboard_cmd.py +113 -0
- mcli/workflow/docker/docker.py +0 -0
- mcli/workflow/file/file.py +100 -0
- mcli/workflow/gcloud/config.toml +21 -0
- mcli/workflow/gcloud/gcloud.py +58 -0
- mcli/workflow/git_commit/ai_service.py +328 -0
- mcli/workflow/git_commit/commands.py +430 -0
- mcli/workflow/lsh_integration.py +355 -0
- mcli/workflow/model_service/client.py +594 -0
- mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
- mcli/workflow/model_service/lightweight_embedder.py +397 -0
- mcli/workflow/model_service/lightweight_model_server.py +714 -0
- mcli/workflow/model_service/lightweight_test.py +241 -0
- mcli/workflow/model_service/model_service.py +1955 -0
- mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
- mcli/workflow/model_service/pdf_processor.py +386 -0
- mcli/workflow/model_service/test_efficient_runner.py +234 -0
- mcli/workflow/model_service/test_example.py +315 -0
- mcli/workflow/model_service/test_integration.py +131 -0
- mcli/workflow/model_service/test_new_features.py +149 -0
- mcli/workflow/openai/openai.py +99 -0
- mcli/workflow/politician_trading/commands.py +1790 -0
- mcli/workflow/politician_trading/config.py +134 -0
- mcli/workflow/politician_trading/connectivity.py +490 -0
- mcli/workflow/politician_trading/data_sources.py +395 -0
- mcli/workflow/politician_trading/database.py +410 -0
- mcli/workflow/politician_trading/demo.py +248 -0
- mcli/workflow/politician_trading/models.py +165 -0
- mcli/workflow/politician_trading/monitoring.py +413 -0
- mcli/workflow/politician_trading/scrapers.py +966 -0
- mcli/workflow/politician_trading/scrapers_california.py +412 -0
- mcli/workflow/politician_trading/scrapers_eu.py +377 -0
- mcli/workflow/politician_trading/scrapers_uk.py +350 -0
- mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
- mcli/workflow/politician_trading/supabase_functions.py +354 -0
- mcli/workflow/politician_trading/workflow.py +852 -0
- mcli/workflow/registry/registry.py +180 -0
- mcli/workflow/repo/repo.py +223 -0
- mcli/workflow/scheduler/commands.py +493 -0
- mcli/workflow/scheduler/cron_parser.py +238 -0
- mcli/workflow/scheduler/job.py +182 -0
- mcli/workflow/scheduler/monitor.py +139 -0
- mcli/workflow/scheduler/persistence.py +324 -0
- mcli/workflow/scheduler/scheduler.py +679 -0
- mcli/workflow/sync/sync_cmd.py +437 -0
- mcli/workflow/sync/test_cmd.py +314 -0
- mcli/workflow/videos/videos.py +242 -0
- mcli/workflow/wakatime/wakatime.py +11 -0
- mcli/workflow/workflow.py +37 -0
- mcli_framework-7.0.0.dist-info/METADATA +479 -0
- mcli_framework-7.0.0.dist-info/RECORD +186 -0
- mcli_framework-7.0.0.dist-info/WHEEL +5 -0
- mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
- mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
- mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
"""Base classes for ML models"""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from typing import Dict, List, Optional, Tuple, Any, Union
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class ModelMetrics:
    """Bundle of performance scores produced for a single model evaluation.

    The five classification scores are mandatory; the remaining fields are
    optional and default to ``None`` until a caller fills them in.
    """

    # Classification scores (always required).
    accuracy: float
    precision: float
    recall: float
    f1_score: float
    auc_roc: float
    # Optional trading-oriented scores; ``None`` means "not computed".
    sharpe_ratio: Optional[float] = None
    max_drawdown: Optional[float] = None
    total_return: Optional[float] = None
    win_rate: Optional[float] = None
    avg_gain: Optional[float] = None
    avg_loss: Optional[float] = None

    def to_dict(self) -> Dict[str, float]:
        """Return every metric keyed by field name.

        Optional metrics that are ``None`` (or otherwise falsy) are reported
        as ``0.0`` so the resulting mapping never contains ``None``.
        """
        always_present = ("accuracy", "precision", "recall", "f1_score", "auc_roc")
        maybe_missing = (
            "sharpe_ratio",
            "max_drawdown",
            "total_return",
            "win_rate",
            "avg_gain",
            "avg_loss",
        )

        as_mapping = {field_name: getattr(self, field_name) for field_name in always_present}
        for field_name in maybe_missing:
            # ``or 0.0`` collapses None (and 0) to the float zero.
            as_mapping[field_name] = getattr(self, field_name) or 0.0
        return as_mapping
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ValidationResult:
    """Everything recorded about one validation run of a model.

    Only the training and validation metrics are required; every other field
    is optional and defaults to ``None``.
    """

    # Required metric bundles.
    train_metrics: ModelMetrics
    val_metrics: ModelMetrics
    # Optional metric bundle for a held-out test set.
    test_metrics: Optional[ModelMetrics] = None
    # Optional mapping of feature name -> importance score.
    feature_importance: Optional[Dict[str, float]] = None
    # Optional arrays of model outputs and the labels they are compared with.
    predictions: Optional[np.ndarray] = None
    true_labels: Optional[np.ndarray] = None
    # Optional mapping of series name -> list of recorded values.
    training_history: Optional[Dict[str, List[float]]] = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class BaseStockModel(nn.Module, ABC):
|
|
62
|
+
"""Abstract base class for all stock prediction models"""
|
|
63
|
+
|
|
64
|
+
def __init__(self, input_dim: int, config: Optional[Dict[str, Any]] = None):
|
|
65
|
+
super().__init__()
|
|
66
|
+
self.input_dim = input_dim
|
|
67
|
+
self.config = config or {}
|
|
68
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
69
|
+
self.is_trained = False
|
|
70
|
+
self.feature_names: Optional[List[str]] = None
|
|
71
|
+
self.scaler = None
|
|
72
|
+
|
|
73
|
+
@abstractmethod
|
|
74
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
75
|
+
"""Forward pass through the model"""
|
|
76
|
+
pass
|
|
77
|
+
|
|
78
|
+
@abstractmethod
|
|
79
|
+
def predict_proba(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> np.ndarray:
|
|
80
|
+
"""Predict class probabilities"""
|
|
81
|
+
pass
|
|
82
|
+
|
|
83
|
+
def predict(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> np.ndarray:
|
|
84
|
+
"""Make binary predictions"""
|
|
85
|
+
probas = self.predict_proba(X)
|
|
86
|
+
return (probas[:, 1] > 0.5).astype(int)
|
|
87
|
+
|
|
88
|
+
def preprocess_input(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> torch.Tensor:
|
|
89
|
+
"""Preprocess input data for model"""
|
|
90
|
+
if isinstance(X, pd.DataFrame):
|
|
91
|
+
X = X.values
|
|
92
|
+
|
|
93
|
+
if isinstance(X, np.ndarray):
|
|
94
|
+
X = torch.FloatTensor(X)
|
|
95
|
+
|
|
96
|
+
# Apply scaling if available
|
|
97
|
+
if self.scaler is not None and not isinstance(X, torch.Tensor):
|
|
98
|
+
X = self.scaler.transform(X)
|
|
99
|
+
X = torch.FloatTensor(X)
|
|
100
|
+
|
|
101
|
+
return X.to(self.device)
|
|
102
|
+
|
|
103
|
+
def get_feature_importance(self) -> Optional[Dict[str, float]]:
|
|
104
|
+
"""Get feature importance scores"""
|
|
105
|
+
# Base implementation returns None
|
|
106
|
+
# Override in specific models that support feature importance
|
|
107
|
+
return None
|
|
108
|
+
|
|
109
|
+
def save_model(self, path: str) -> None:
|
|
110
|
+
"""Save model state"""
|
|
111
|
+
state = {
|
|
112
|
+
"model_state_dict": self.state_dict(),
|
|
113
|
+
"config": self.config,
|
|
114
|
+
"input_dim": self.input_dim,
|
|
115
|
+
"feature_names": self.feature_names,
|
|
116
|
+
"scaler": self.scaler,
|
|
117
|
+
"is_trained": self.is_trained,
|
|
118
|
+
}
|
|
119
|
+
torch.save(state, path)
|
|
120
|
+
logger.info(f"Model saved to {path}")
|
|
121
|
+
|
|
122
|
+
def load_model(self, path: str) -> None:
|
|
123
|
+
"""Load model state"""
|
|
124
|
+
state = torch.load(path, map_location=self.device)
|
|
125
|
+
self.load_state_dict(state["model_state_dict"])
|
|
126
|
+
self.config = state["config"]
|
|
127
|
+
self.input_dim = state["input_dim"]
|
|
128
|
+
self.feature_names = state.get("feature_names")
|
|
129
|
+
self.scaler = state.get("scaler")
|
|
130
|
+
self.is_trained = state.get("is_trained", False)
|
|
131
|
+
logger.info(f"Model loaded from {path}")
|
|
132
|
+
|
|
133
|
+
def calculate_metrics(
|
|
134
|
+
self, y_true: np.ndarray, y_pred: np.ndarray, y_proba: Optional[np.ndarray] = None
|
|
135
|
+
) -> ModelMetrics:
|
|
136
|
+
"""Calculate comprehensive model metrics"""
|
|
137
|
+
from sklearn.metrics import (
|
|
138
|
+
accuracy_score,
|
|
139
|
+
precision_score,
|
|
140
|
+
recall_score,
|
|
141
|
+
f1_score,
|
|
142
|
+
roc_auc_score,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# Basic classification metrics
|
|
146
|
+
accuracy = accuracy_score(y_true, y_pred)
|
|
147
|
+
precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
|
|
148
|
+
recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
|
|
149
|
+
f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
|
|
150
|
+
|
|
151
|
+
# AUC-ROC
|
|
152
|
+
auc_roc = 0.0
|
|
153
|
+
if y_proba is not None and len(np.unique(y_true)) > 1:
|
|
154
|
+
try:
|
|
155
|
+
if y_proba.ndim > 1 and y_proba.shape[1] > 1:
|
|
156
|
+
auc_roc = roc_auc_score(y_true, y_proba[:, 1])
|
|
157
|
+
else:
|
|
158
|
+
auc_roc = roc_auc_score(y_true, y_proba)
|
|
159
|
+
except Exception as e:
|
|
160
|
+
logger.warning(f"Could not calculate AUC-ROC: {e}")
|
|
161
|
+
|
|
162
|
+
# Trading-specific metrics (simplified)
|
|
163
|
+
win_rate = np.mean(y_pred == 1) if len(y_pred) > 0 else 0.0
|
|
164
|
+
|
|
165
|
+
return ModelMetrics(
|
|
166
|
+
accuracy=accuracy,
|
|
167
|
+
precision=precision,
|
|
168
|
+
recall=recall,
|
|
169
|
+
f1_score=f1,
|
|
170
|
+
auc_roc=auc_roc,
|
|
171
|
+
win_rate=win_rate,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
def to(self, device):
|
|
175
|
+
"""Move model to device and update internal device reference"""
|
|
176
|
+
self.device = device
|
|
177
|
+
return super().to(device)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class MLPBaseModel(BaseStockModel):
    """Basic Multi-Layer Perceptron base model.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU -> Dropout``
    groups, one per entry of ``hidden_dims``, followed by a final ``Linear``
    layer that produces ``output_dim`` logits.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Union[List[int], Tuple[int, ...]] = (512, 256, 128),
        output_dim: int = 2,
        dropout_rate: float = 0.3,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Build the layer stack and initialise its weights.

        Args:
            input_dim: Number of input features.
            hidden_dims: Width of each hidden layer, in order.
            output_dim: Number of output logits.
            dropout_rate: Dropout probability used after every hidden layer.
            config: Optional configuration forwarded to the base class.
        """
        super().__init__(input_dim, config)

        # Copy so instances never share (or mutate) the default or the
        # caller's sequence; the default is an immutable tuple for the same
        # reason.
        self.hidden_dims = list(hidden_dims)
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate

        # Build network layers
        layers = []
        prev_dim = input_dim

        for hidden_dim in self.hidden_dims:
            layers.extend(
                [
                    nn.Linear(prev_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                ]
            )
            prev_dim = hidden_dim

        # Output layer
        layers.append(nn.Linear(prev_dim, output_dim))

        self.network = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=1)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every ``nn.Linear``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits of the network for ``x``."""
        return self.network(x)

    def predict_proba(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict class probabilities.

        Switches the model to eval mode, runs a gradient-free forward pass on
        the preprocessed input and returns the softmax over dimension 1 as a
        NumPy array.
        """
        self.eval()
        with torch.no_grad():
            X_tensor = self.preprocess_input(X)
            logits = self.forward(X_tensor)
            probas = self.softmax(logits)
            return probas.cpu().numpy()
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class ResidualBlock(nn.Module):
    """Two-layer fully connected block with an identity skip connection.

    The transform is ``Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear ->
    BatchNorm1d``; its output is added to the block input and passed through
    a final ReLU. Input and output widths are both ``dim``.
    """

    def __init__(self, dim: int, dropout_rate: float = 0.1):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation and state-dict keys are unchanged.
        transform_layers = [
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        ]
        self.block = nn.Sequential(*transform_layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(block(x) + x)``."""
        return self.relu(self.block(x) + x)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class ResNetModel(BaseStockModel):
    """ResNet-style model for tabular data.

    An input projection (``Linear -> BatchNorm1d -> ReLU -> Dropout``) maps
    the features to ``hidden_dim``, followed by ``num_blocks`` residual blocks
    and a final ``Linear`` layer producing ``output_dim`` logits.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 3,
        output_dim: int = 2,
        dropout_rate: float = 0.2,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Build the projection, residual stack and output layer.

        Args:
            input_dim: Number of input features.
            hidden_dim: Width used by the projection and every residual block.
            num_blocks: Number of residual blocks.
            output_dim: Number of output logits.
            dropout_rate: Dropout probability for the projection and blocks.
            config: Optional configuration forwarded to the base class.
        """
        super().__init__(input_dim, config)

        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks
        self.output_dim = output_dim
        # Kept on the instance like the other hyperparameters (and like
        # MLPBaseModel does) so the architecture can be inspected later.
        self.dropout_rate = dropout_rate

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        # Residual blocks
        self.blocks = nn.ModuleList(
            [ResidualBlock(hidden_dim, dropout_rate) for _ in range(num_blocks)]
        )

        # Output layer
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Kaiming-normal (fan_out, ReLU) weights and zero biases for ``nn.Linear``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits: projection, then each block, then output layer."""
        x = self.input_proj(x)

        for block in self.blocks:
            x = block(x)

        return self.output_layer(x)

    def predict_proba(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Predict class probabilities.

        Switches the model to eval mode, runs a gradient-free forward pass on
        the preprocessed input and returns the softmax over dimension 1 as a
        NumPy array.
        """
        self.eval()
        with torch.no_grad():
            X_tensor = self.preprocess_input(X)
            logits = self.forward(X_tensor)
            probas = self.softmax(logits)
            return probas.cpu().numpy()
|