mcli_framework-7.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mcli-framework might be problematic.
- mcli/app/chat_cmd.py +42 -0
- mcli/app/commands_cmd.py +226 -0
- mcli/app/completion_cmd.py +216 -0
- mcli/app/completion_helpers.py +288 -0
- mcli/app/cron_test_cmd.py +697 -0
- mcli/app/logs_cmd.py +419 -0
- mcli/app/main.py +492 -0
- mcli/app/model/model.py +1060 -0
- mcli/app/model_cmd.py +227 -0
- mcli/app/redis_cmd.py +269 -0
- mcli/app/video/video.py +1114 -0
- mcli/app/visual_cmd.py +303 -0
- mcli/chat/chat.py +2409 -0
- mcli/chat/command_rag.py +514 -0
- mcli/chat/enhanced_chat.py +652 -0
- mcli/chat/system_controller.py +1010 -0
- mcli/chat/system_integration.py +1016 -0
- mcli/cli.py +25 -0
- mcli/config.toml +20 -0
- mcli/lib/api/api.py +586 -0
- mcli/lib/api/daemon_client.py +203 -0
- mcli/lib/api/daemon_client_local.py +44 -0
- mcli/lib/api/daemon_decorator.py +217 -0
- mcli/lib/api/mcli_decorators.py +1032 -0
- mcli/lib/auth/auth.py +85 -0
- mcli/lib/auth/aws_manager.py +85 -0
- mcli/lib/auth/azure_manager.py +91 -0
- mcli/lib/auth/credential_manager.py +192 -0
- mcli/lib/auth/gcp_manager.py +93 -0
- mcli/lib/auth/key_manager.py +117 -0
- mcli/lib/auth/mcli_manager.py +93 -0
- mcli/lib/auth/token_manager.py +75 -0
- mcli/lib/auth/token_util.py +1011 -0
- mcli/lib/config/config.py +47 -0
- mcli/lib/discovery/__init__.py +1 -0
- mcli/lib/discovery/command_discovery.py +274 -0
- mcli/lib/erd/erd.py +1345 -0
- mcli/lib/erd/generate_graph.py +453 -0
- mcli/lib/files/files.py +76 -0
- mcli/lib/fs/fs.py +109 -0
- mcli/lib/lib.py +29 -0
- mcli/lib/logger/logger.py +611 -0
- mcli/lib/performance/optimizer.py +409 -0
- mcli/lib/performance/rust_bridge.py +502 -0
- mcli/lib/performance/uvloop_config.py +154 -0
- mcli/lib/pickles/pickles.py +50 -0
- mcli/lib/search/cached_vectorizer.py +479 -0
- mcli/lib/services/data_pipeline.py +460 -0
- mcli/lib/services/lsh_client.py +441 -0
- mcli/lib/services/redis_service.py +387 -0
- mcli/lib/shell/shell.py +137 -0
- mcli/lib/toml/toml.py +33 -0
- mcli/lib/ui/styling.py +47 -0
- mcli/lib/ui/visual_effects.py +634 -0
- mcli/lib/watcher/watcher.py +185 -0
- mcli/ml/api/app.py +215 -0
- mcli/ml/api/middleware.py +224 -0
- mcli/ml/api/routers/admin_router.py +12 -0
- mcli/ml/api/routers/auth_router.py +244 -0
- mcli/ml/api/routers/backtest_router.py +12 -0
- mcli/ml/api/routers/data_router.py +12 -0
- mcli/ml/api/routers/model_router.py +302 -0
- mcli/ml/api/routers/monitoring_router.py +12 -0
- mcli/ml/api/routers/portfolio_router.py +12 -0
- mcli/ml/api/routers/prediction_router.py +267 -0
- mcli/ml/api/routers/trade_router.py +12 -0
- mcli/ml/api/routers/websocket_router.py +76 -0
- mcli/ml/api/schemas.py +64 -0
- mcli/ml/auth/auth_manager.py +425 -0
- mcli/ml/auth/models.py +154 -0
- mcli/ml/auth/permissions.py +302 -0
- mcli/ml/backtesting/backtest_engine.py +502 -0
- mcli/ml/backtesting/performance_metrics.py +393 -0
- mcli/ml/cache.py +400 -0
- mcli/ml/cli/main.py +398 -0
- mcli/ml/config/settings.py +394 -0
- mcli/ml/configs/dvc_config.py +230 -0
- mcli/ml/configs/mlflow_config.py +131 -0
- mcli/ml/configs/mlops_manager.py +293 -0
- mcli/ml/dashboard/app.py +532 -0
- mcli/ml/dashboard/app_integrated.py +738 -0
- mcli/ml/dashboard/app_supabase.py +560 -0
- mcli/ml/dashboard/app_training.py +615 -0
- mcli/ml/dashboard/cli.py +51 -0
- mcli/ml/data_ingestion/api_connectors.py +501 -0
- mcli/ml/data_ingestion/data_pipeline.py +567 -0
- mcli/ml/data_ingestion/stream_processor.py +512 -0
- mcli/ml/database/migrations/env.py +94 -0
- mcli/ml/database/models.py +667 -0
- mcli/ml/database/session.py +200 -0
- mcli/ml/experimentation/ab_testing.py +845 -0
- mcli/ml/features/ensemble_features.py +607 -0
- mcli/ml/features/political_features.py +676 -0
- mcli/ml/features/recommendation_engine.py +809 -0
- mcli/ml/features/stock_features.py +573 -0
- mcli/ml/features/test_feature_engineering.py +346 -0
- mcli/ml/logging.py +85 -0
- mcli/ml/mlops/data_versioning.py +518 -0
- mcli/ml/mlops/experiment_tracker.py +377 -0
- mcli/ml/mlops/model_serving.py +481 -0
- mcli/ml/mlops/pipeline_orchestrator.py +614 -0
- mcli/ml/models/base_models.py +324 -0
- mcli/ml/models/ensemble_models.py +675 -0
- mcli/ml/models/recommendation_models.py +474 -0
- mcli/ml/models/test_models.py +487 -0
- mcli/ml/monitoring/drift_detection.py +676 -0
- mcli/ml/monitoring/metrics.py +45 -0
- mcli/ml/optimization/portfolio_optimizer.py +834 -0
- mcli/ml/preprocessing/data_cleaners.py +451 -0
- mcli/ml/preprocessing/feature_extractors.py +491 -0
- mcli/ml/preprocessing/ml_pipeline.py +382 -0
- mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
- mcli/ml/preprocessing/test_preprocessing.py +294 -0
- mcli/ml/scripts/populate_sample_data.py +200 -0
- mcli/ml/tasks.py +400 -0
- mcli/ml/tests/test_integration.py +429 -0
- mcli/ml/tests/test_training_dashboard.py +387 -0
- mcli/public/oi/oi.py +15 -0
- mcli/public/public.py +4 -0
- mcli/self/self_cmd.py +1246 -0
- mcli/workflow/daemon/api_daemon.py +800 -0
- mcli/workflow/daemon/async_command_database.py +681 -0
- mcli/workflow/daemon/async_process_manager.py +591 -0
- mcli/workflow/daemon/client.py +530 -0
- mcli/workflow/daemon/commands.py +1196 -0
- mcli/workflow/daemon/daemon.py +905 -0
- mcli/workflow/daemon/daemon_api.py +59 -0
- mcli/workflow/daemon/enhanced_daemon.py +571 -0
- mcli/workflow/daemon/process_cli.py +244 -0
- mcli/workflow/daemon/process_manager.py +439 -0
- mcli/workflow/daemon/test_daemon.py +275 -0
- mcli/workflow/dashboard/dashboard_cmd.py +113 -0
- mcli/workflow/docker/docker.py +0 -0
- mcli/workflow/file/file.py +100 -0
- mcli/workflow/gcloud/config.toml +21 -0
- mcli/workflow/gcloud/gcloud.py +58 -0
- mcli/workflow/git_commit/ai_service.py +328 -0
- mcli/workflow/git_commit/commands.py +430 -0
- mcli/workflow/lsh_integration.py +355 -0
- mcli/workflow/model_service/client.py +594 -0
- mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
- mcli/workflow/model_service/lightweight_embedder.py +397 -0
- mcli/workflow/model_service/lightweight_model_server.py +714 -0
- mcli/workflow/model_service/lightweight_test.py +241 -0
- mcli/workflow/model_service/model_service.py +1955 -0
- mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
- mcli/workflow/model_service/pdf_processor.py +386 -0
- mcli/workflow/model_service/test_efficient_runner.py +234 -0
- mcli/workflow/model_service/test_example.py +315 -0
- mcli/workflow/model_service/test_integration.py +131 -0
- mcli/workflow/model_service/test_new_features.py +149 -0
- mcli/workflow/openai/openai.py +99 -0
- mcli/workflow/politician_trading/commands.py +1790 -0
- mcli/workflow/politician_trading/config.py +134 -0
- mcli/workflow/politician_trading/connectivity.py +490 -0
- mcli/workflow/politician_trading/data_sources.py +395 -0
- mcli/workflow/politician_trading/database.py +410 -0
- mcli/workflow/politician_trading/demo.py +248 -0
- mcli/workflow/politician_trading/models.py +165 -0
- mcli/workflow/politician_trading/monitoring.py +413 -0
- mcli/workflow/politician_trading/scrapers.py +966 -0
- mcli/workflow/politician_trading/scrapers_california.py +412 -0
- mcli/workflow/politician_trading/scrapers_eu.py +377 -0
- mcli/workflow/politician_trading/scrapers_uk.py +350 -0
- mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
- mcli/workflow/politician_trading/supabase_functions.py +354 -0
- mcli/workflow/politician_trading/workflow.py +852 -0
- mcli/workflow/registry/registry.py +180 -0
- mcli/workflow/repo/repo.py +223 -0
- mcli/workflow/scheduler/commands.py +493 -0
- mcli/workflow/scheduler/cron_parser.py +238 -0
- mcli/workflow/scheduler/job.py +182 -0
- mcli/workflow/scheduler/monitor.py +139 -0
- mcli/workflow/scheduler/persistence.py +324 -0
- mcli/workflow/scheduler/scheduler.py +679 -0
- mcli/workflow/sync/sync_cmd.py +437 -0
- mcli/workflow/sync/test_cmd.py +314 -0
- mcli/workflow/videos/videos.py +242 -0
- mcli/workflow/wakatime/wakatime.py +11 -0
- mcli/workflow/workflow.py +37 -0
- mcli_framework-7.0.0.dist-info/METADATA +479 -0
- mcli_framework-7.0.0.dist-info/RECORD +186 -0
- mcli_framework-7.0.0.dist-info/WHEEL +5 -0
- mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
- mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
- mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
mcli/ml/models/test_models.py
@@ -0,0 +1,487 @@
+"""Test script for ensemble models"""
+
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../.."))
+
+import numpy as np
+import pandas as pd
+import torch
+from datetime import datetime, timedelta
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def generate_mock_features(n_samples: int = 500, n_features: int = 150) -> pd.DataFrame:
+    """Generate mock feature data for testing"""
+    np.random.seed(42)
+
+    # Create realistic feature names
+    feature_names = []
+
+    # Technical indicators
+    for indicator in ["sma", "ema", "rsi", "macd", "bb", "volume", "volatility"]:
+        for period in [5, 10, 20, 50]:
+            feature_names.append(f"{indicator}_{period}")
+
+    # Political features
+    for pol_feature in ["influence", "trading_freq", "committee_align", "seniority"]:
+        for agg in ["mean", "max", "std"]:
+            feature_names.append(f"political_{pol_feature}_{agg}")
+
+    # Ensemble features
+    for i in range(50):
+        feature_names.append(f"ensemble_feature_{i}")
+
+    # Market regime features
+    for regime in ["volatility", "trend", "volume"]:
+        for metric in ["regime", "strength", "persistence"]:
+            feature_names.append(f"market_{regime}_{metric}")
+
+    # Pad or trim to exact number
+    while len(feature_names) < n_features:
+        feature_names.append(f"extra_feature_{len(feature_names)}")
+    feature_names = feature_names[:n_features]
+
+    # Generate correlated features that simulate real market data
+    features = []
+    for i in range(n_samples):
+        # Base market trend
+        market_trend = np.random.normal(0, 1)
+
+        # Technical features (correlated with trend)
+        tech_features = np.random.normal(market_trend * 0.3, 0.8, 32)
+
+        # Political features (some correlation with market)
+        pol_features = np.random.normal(market_trend * 0.1, 0.5, 12)
+
+        # Ensemble features (mix of correlated and noise)
+        ensemble_features = np.random.normal(market_trend * 0.2, 0.6, 50)
+
+        # Market regime features
+        regime_features = np.random.normal(market_trend * 0.4, 0.7, 9)
+
+        # Extra random features
+        n_extra = max(0, n_features - 103)  # Ensure non-negative
+        if n_extra > 0:
+            extra_features = np.random.normal(0, 0.5, n_extra)
+            sample_features = np.concatenate(
+                [tech_features, pol_features, ensemble_features, regime_features, extra_features]
+            )
+        else:
+            # Truncate if we have too many features
+            all_features = np.concatenate([tech_features, pol_features, ensemble_features, regime_features])
+            sample_features = all_features[:n_features]
+        features.append(sample_features)
+
+    return pd.DataFrame(features, columns=feature_names)
+
+
+def generate_mock_targets(n_samples: int) -> tuple:
+    """Generate realistic target variables"""
+    np.random.seed(42)
+
+    # Generate correlated targets
+    market_performance = np.random.normal(0, 1, n_samples)
+
+    # Binary classification target (profitable vs not)
+    binary_targets = (market_performance > 0).astype(int)
+
+    # Continuous returns (with realistic distribution)
+    returns = np.random.normal(0.05, 0.15, n_samples)  # 5% mean, 15% volatility
+
+    # Risk labels (low=0, medium=1, high=2)
+    risk_labels = np.random.choice([0, 1, 2], n_samples, p=[0.3, 0.5, 0.2])
+
+    return binary_targets, returns, risk_labels
+
+
+def test_base_models():
+    """Test base model functionality"""
+    logger.info("Testing base models...")
+
+    from base_models import MLPBaseModel, ResNetModel
+
+    # Generate test data
+    X = generate_mock_features(100, 50)
+    y, _, _ = generate_mock_targets(100)
+
+    # Test MLP model
+    mlp_model = MLPBaseModel(input_dim=50, hidden_dims=[128, 64], dropout_rate=0.2)
+
+    # Test forward pass
+    X_tensor = torch.FloatTensor(X.values)
+    output = mlp_model(X_tensor)
+    logger.info(f"MLP output shape: {output.shape}")
+
+    # Test prediction
+    probas = mlp_model.predict_proba(X)
+    predictions = mlp_model.predict(X)
+    logger.info(f"MLP predictions shape: {predictions.shape}")
+
+    # Test ResNet model
+    resnet_model = ResNetModel(input_dim=50, hidden_dim=128, num_blocks=2)
+
+    output = resnet_model(X_tensor)
+    logger.info(f"ResNet output shape: {output.shape}")
+
+    # Test metrics calculation
+    metrics = mlp_model.calculate_metrics(y, predictions)
+    logger.info(f"Model metrics: {metrics}")
+
+    logger.info("✅ Base models test passed")
+
+
+def test_ensemble_models():
+    """Test ensemble model functionality"""
+    logger.info("Testing ensemble models...")
+
+    from ensemble_models import (
+        DeepEnsembleModel,
+        EnsembleConfig,
+        ModelConfig,
+        AttentionStockPredictor,
+        TransformerStockModel,
+        LSTMStockPredictor,
+        CNNFeatureExtractor,
+    )
+
+    # Generate test data
+    X = generate_mock_features(200, 100)
+    y, _, _ = generate_mock_targets(200)
+
+    input_dim = X.shape[1]
+
+    # Test individual models
+    logger.info("Testing individual ensemble components...")
+
+    # Attention model
+    attention_model = AttentionStockPredictor(input_dim, hidden_dim=64, num_heads=4, num_layers=2)
+    output = attention_model(torch.FloatTensor(X.values[:10]))
+    logger.info(f"Attention model output shape: {output.shape}")
+
+    # Transformer model
+    transformer_model = TransformerStockModel(input_dim, d_model=64, nhead=4, num_layers=2)
+    output = transformer_model(torch.FloatTensor(X.values[:10]))
+    logger.info(f"Transformer model output shape: {output.shape}")
+
+    # LSTM model
+    lstm_model = LSTMStockPredictor(input_dim, hidden_dim=64, num_layers=2)
+    output = lstm_model(torch.FloatTensor(X.values[:10]))
+    logger.info(f"LSTM model output shape: {output.shape}")
+
+    # CNN model
+    cnn_model = CNNFeatureExtractor(input_dim, num_filters=32, filter_sizes=[3, 5])
+    output = cnn_model(torch.FloatTensor(X.values[:10]))
+    logger.info(f"CNN model output shape: {output.shape}")
+
+    # Test ensemble configuration
+    model_configs = [
+        ModelConfig(
+            model_type="attention",
+            hidden_dims=[128],
+            dropout_rate=0.2,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=5,
+        ),
+        ModelConfig(
+            model_type="lstm",
+            hidden_dims=[128],
+            dropout_rate=0.2,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=5,
+        ),
+        ModelConfig(
+            model_type="mlp",
+            hidden_dims=[256, 128],
+            dropout_rate=0.3,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=5,
+        ),
+    ]
+
+    ensemble_config = EnsembleConfig(
+        base_models=model_configs,
+        ensemble_method="weighted_average",
+        feature_subsampling=True,
+        bootstrap_samples=True,
+    )
+
+    # Create ensemble model
+    ensemble_model = DeepEnsembleModel(input_dim, ensemble_config)
+
+    # Test forward pass
+    X_test = torch.FloatTensor(X.values[:20])
+    ensemble_output = ensemble_model(X_test)
+    logger.info(f"Ensemble output shape: {ensemble_output.shape}")
+
+    # Test individual predictions
+    individual_preds = ensemble_model.get_individual_predictions(X.values[:20])
+    logger.info(f"Individual predictions: {len(individual_preds)} models")
+
+    # Test prediction methods
+    ensemble_probas = ensemble_model.predict_proba(X.values[:20])
+    ensemble_preds = ensemble_model.predict(X.values[:20])
+    logger.info(f"Ensemble predictions shape: {ensemble_preds.shape}")
+
+    logger.info("✅ Ensemble models test passed")
+
+
+def test_recommendation_model():
+    """Test recommendation model"""
+    logger.info("Testing recommendation model...")
+
+    from recommendation_models import (
+        StockRecommendationModel,
+        RecommendationConfig,
+        PortfolioRecommendation,
+        RecommendationTrainer,
+    )
+    from ensemble_models import EnsembleConfig, ModelConfig
+
+    # Generate test data
+    X = generate_mock_features(300, 120)
+    y, returns, risk_labels = generate_mock_targets(300)
+
+    input_dim = X.shape[1]
+    tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]
+
+    # Create smaller ensemble for testing
+    model_configs = [
+        ModelConfig(
+            model_type="mlp",
+            hidden_dims=[128, 64],
+            dropout_rate=0.2,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=3,
+        ),
+        ModelConfig(
+            model_type="attention",
+            hidden_dims=[64],
+            dropout_rate=0.2,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=3,
+        ),
+    ]
+
+    ensemble_config = EnsembleConfig(base_models=model_configs, ensemble_method="weighted_average")
+
+    recommendation_config = RecommendationConfig(
+        ensemble_config=ensemble_config,
+        risk_adjustment=True,
+        confidence_threshold=0.4,  # Lower for testing
+        max_positions=5,
+    )
+
+    # Create recommendation model
+    rec_model = StockRecommendationModel(input_dim, recommendation_config)
+
+    # Test forward pass
+    X_test = torch.FloatTensor(X.values[:10])
+    outputs = rec_model(X_test)
+
+    expected_keys = ["main_prediction", "risk_assessment", "expected_returns", "confidence"]
+    for key in expected_keys:
+        assert key in outputs, f"Missing output: {key}"
+        logger.info(f"{key} shape: {outputs[key].shape}")
+
+    # Test recommendation generation
+    recommendations = rec_model.generate_recommendations(X.values[:5], tickers, market_data=None)
+
+    logger.info(f"Generated {len(recommendations)} recommendations")
+
+    for rec in recommendations:
+        logger.info(f"Ticker: {rec.ticker}")
+        logger.info(f"  Score: {rec.recommendation_score:.3f}")
+        logger.info(f"  Confidence: {rec.confidence:.3f}")
+        logger.info(f"  Risk: {rec.risk_level}")
+        logger.info(f"  Position: {rec.position_size:.3f}")
+        logger.info(f"  Reason: {rec.recommendation_reason}")
+
+    # Validate recommendation structure
+    for rec in recommendations:
+        assert isinstance(rec, PortfolioRecommendation)
+        assert 0 <= rec.recommendation_score <= 1
+        assert 0 <= rec.confidence <= 1
+        assert rec.risk_level in ["low", "medium", "high"]
+        assert 0 <= rec.position_size <= 1
+        assert isinstance(rec.key_features, list)
+        assert isinstance(rec.warnings, list)
+
+    logger.info("✅ Recommendation model test passed")
+
+
+def test_model_training():
+    """Test model training functionality"""
+    logger.info("Testing model training...")
+
+    from recommendation_models import (
+        StockRecommendationModel,
+        RecommendationTrainer,
+        RecommendationConfig,
+    )
+    from ensemble_models import EnsembleConfig, ModelConfig, EnsembleTrainer
+
+    # Generate training data
+    X_train = generate_mock_features(200, 80)
+    X_val = generate_mock_features(50, 80)
+
+    y_train, returns_train, risk_train = generate_mock_targets(200)
+    y_val, returns_val, risk_val = generate_mock_targets(50)
+
+    input_dim = X_train.shape[1]
+
+    # Simple ensemble for faster training
+    model_configs = [
+        ModelConfig(
+            model_type="mlp",
+            hidden_dims=[64, 32],
+            dropout_rate=0.2,
+            learning_rate=0.001,
+            weight_decay=1e-4,
+            batch_size=32,
+            epochs=2,
+        )
+    ]
+
+    ensemble_config = EnsembleConfig(base_models=model_configs, ensemble_method="weighted_average")
+
+    recommendation_config = RecommendationConfig(
+        ensemble_config=ensemble_config, confidence_threshold=0.3
+    )
+
+    # Test ensemble training
+    from ensemble_models import DeepEnsembleModel
+
+    ensemble_model = DeepEnsembleModel(input_dim, ensemble_config)
+    ensemble_trainer = EnsembleTrainer(ensemble_model, ensemble_config)
+
+    logger.info("Training ensemble model...")
+    ensemble_result = ensemble_trainer.train(X_train.values, y_train, X_val.values, y_val)
+
+    logger.info(f"Ensemble training metrics:")
+    logger.info(f"  Train accuracy: {ensemble_result.train_metrics.accuracy:.3f}")
+    logger.info(f"  Val accuracy: {ensemble_result.val_metrics.accuracy:.3f}")
+
+    # Test recommendation model training
+    rec_model = StockRecommendationModel(input_dim, recommendation_config)
+    rec_trainer = RecommendationTrainer(rec_model, recommendation_config)
+
+    logger.info("Training recommendation model...")
+    rec_result = rec_trainer.train(
+        X_train.values,
+        y_train,
+        returns_train,
+        risk_train,
+        X_val.values,
+        y_val,
+        returns_val,
+        risk_val,
+        epochs=5,
+        batch_size=32,
+    )
+
+    logger.info(f"Recommendation training metrics:")
+    logger.info(f"  Train accuracy: {rec_result.train_metrics.accuracy:.3f}")
+    logger.info(f"  Val accuracy: {rec_result.val_metrics.accuracy:.3f}")
+
+    # Test trained model predictions
+    test_recommendations = rec_model.generate_recommendations(
+        X_val.values[:3], ["AAPL", "MSFT", "GOOGL"]
+    )
+
+    logger.info(f"Generated {len(test_recommendations)} test recommendations")
+
+    logger.info("✅ Model training test passed")
+
+
+def test_model_persistence():
+    """Test model saving and loading"""
+    logger.info("Testing model persistence...")
+
+    from base_models import MLPBaseModel
+    import tempfile
+
+    # Create and test model
+    model = MLPBaseModel(input_dim=50, hidden_dims=[64, 32])
+    X_test = generate_mock_features(10, 50)
+
+    # Get initial predictions
+    original_preds = model.predict_proba(X_test)
+
+    # Save model
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        model_path = f.name
+
+    model.save_model(model_path)
+
+    # Create new model and load
+    new_model = MLPBaseModel(input_dim=50, hidden_dims=[64, 32])
+    new_model.load_model(model_path)
+
+    # Compare predictions
+    loaded_preds = new_model.predict_proba(X_test)
+
+    # Should be identical
+    np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=6)
+
+    # Cleanup
+    os.unlink(model_path)
+
+    logger.info("✅ Model persistence test passed")
+
+
+def main():
+    """Run all model tests"""
+    logger.info("Starting ensemble model tests...")
+
+    try:
+        # Test individual components
+        test_base_models()
+        test_ensemble_models()
+        test_recommendation_model()
+        test_model_training()
+        test_model_persistence()
+
+        logger.info("🎉 All ensemble model tests passed!")
+
+        # Print summary
+        logger.info("\n" + "=" * 60)
+        logger.info("PYTORCH ENSEMBLE MODEL SYSTEM SUMMARY")
+        logger.info("=" * 60)
+        logger.info("✅ Base models: MLP, ResNet with proper abstractions")
+        logger.info("✅ Ensemble models: Attention, Transformer, LSTM, CNN")
+        logger.info("✅ Deep ensemble: Weighted averaging, voting, stacking")
+        logger.info("✅ Recommendation system: Portfolio optimization")
+        logger.info("✅ Training pipeline: Multi-task learning")
+        logger.info("✅ Model persistence: Save/load functionality")
+        logger.info("✅ Comprehensive metrics: Classification, regression, trading")
+        logger.info("=" * 60)
+
+        return True
+
+    except Exception as e:
+        logger.error(f"❌ Ensemble model tests failed: {e}")
+        import traceback
+
+        traceback.print_exc()
+        return False
+
+
+if __name__ == "__main__":
+    success = main()
+    sys.exit(0 if success else 1)