mcli-framework 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcli-framework might be problematic. Click here for more details.

Files changed (186)
  1. mcli/app/chat_cmd.py +42 -0
  2. mcli/app/commands_cmd.py +226 -0
  3. mcli/app/completion_cmd.py +216 -0
  4. mcli/app/completion_helpers.py +288 -0
  5. mcli/app/cron_test_cmd.py +697 -0
  6. mcli/app/logs_cmd.py +419 -0
  7. mcli/app/main.py +492 -0
  8. mcli/app/model/model.py +1060 -0
  9. mcli/app/model_cmd.py +227 -0
  10. mcli/app/redis_cmd.py +269 -0
  11. mcli/app/video/video.py +1114 -0
  12. mcli/app/visual_cmd.py +303 -0
  13. mcli/chat/chat.py +2409 -0
  14. mcli/chat/command_rag.py +514 -0
  15. mcli/chat/enhanced_chat.py +652 -0
  16. mcli/chat/system_controller.py +1010 -0
  17. mcli/chat/system_integration.py +1016 -0
  18. mcli/cli.py +25 -0
  19. mcli/config.toml +20 -0
  20. mcli/lib/api/api.py +586 -0
  21. mcli/lib/api/daemon_client.py +203 -0
  22. mcli/lib/api/daemon_client_local.py +44 -0
  23. mcli/lib/api/daemon_decorator.py +217 -0
  24. mcli/lib/api/mcli_decorators.py +1032 -0
  25. mcli/lib/auth/auth.py +85 -0
  26. mcli/lib/auth/aws_manager.py +85 -0
  27. mcli/lib/auth/azure_manager.py +91 -0
  28. mcli/lib/auth/credential_manager.py +192 -0
  29. mcli/lib/auth/gcp_manager.py +93 -0
  30. mcli/lib/auth/key_manager.py +117 -0
  31. mcli/lib/auth/mcli_manager.py +93 -0
  32. mcli/lib/auth/token_manager.py +75 -0
  33. mcli/lib/auth/token_util.py +1011 -0
  34. mcli/lib/config/config.py +47 -0
  35. mcli/lib/discovery/__init__.py +1 -0
  36. mcli/lib/discovery/command_discovery.py +274 -0
  37. mcli/lib/erd/erd.py +1345 -0
  38. mcli/lib/erd/generate_graph.py +453 -0
  39. mcli/lib/files/files.py +76 -0
  40. mcli/lib/fs/fs.py +109 -0
  41. mcli/lib/lib.py +29 -0
  42. mcli/lib/logger/logger.py +611 -0
  43. mcli/lib/performance/optimizer.py +409 -0
  44. mcli/lib/performance/rust_bridge.py +502 -0
  45. mcli/lib/performance/uvloop_config.py +154 -0
  46. mcli/lib/pickles/pickles.py +50 -0
  47. mcli/lib/search/cached_vectorizer.py +479 -0
  48. mcli/lib/services/data_pipeline.py +460 -0
  49. mcli/lib/services/lsh_client.py +441 -0
  50. mcli/lib/services/redis_service.py +387 -0
  51. mcli/lib/shell/shell.py +137 -0
  52. mcli/lib/toml/toml.py +33 -0
  53. mcli/lib/ui/styling.py +47 -0
  54. mcli/lib/ui/visual_effects.py +634 -0
  55. mcli/lib/watcher/watcher.py +185 -0
  56. mcli/ml/api/app.py +215 -0
  57. mcli/ml/api/middleware.py +224 -0
  58. mcli/ml/api/routers/admin_router.py +12 -0
  59. mcli/ml/api/routers/auth_router.py +244 -0
  60. mcli/ml/api/routers/backtest_router.py +12 -0
  61. mcli/ml/api/routers/data_router.py +12 -0
  62. mcli/ml/api/routers/model_router.py +302 -0
  63. mcli/ml/api/routers/monitoring_router.py +12 -0
  64. mcli/ml/api/routers/portfolio_router.py +12 -0
  65. mcli/ml/api/routers/prediction_router.py +267 -0
  66. mcli/ml/api/routers/trade_router.py +12 -0
  67. mcli/ml/api/routers/websocket_router.py +76 -0
  68. mcli/ml/api/schemas.py +64 -0
  69. mcli/ml/auth/auth_manager.py +425 -0
  70. mcli/ml/auth/models.py +154 -0
  71. mcli/ml/auth/permissions.py +302 -0
  72. mcli/ml/backtesting/backtest_engine.py +502 -0
  73. mcli/ml/backtesting/performance_metrics.py +393 -0
  74. mcli/ml/cache.py +400 -0
  75. mcli/ml/cli/main.py +398 -0
  76. mcli/ml/config/settings.py +394 -0
  77. mcli/ml/configs/dvc_config.py +230 -0
  78. mcli/ml/configs/mlflow_config.py +131 -0
  79. mcli/ml/configs/mlops_manager.py +293 -0
  80. mcli/ml/dashboard/app.py +532 -0
  81. mcli/ml/dashboard/app_integrated.py +738 -0
  82. mcli/ml/dashboard/app_supabase.py +560 -0
  83. mcli/ml/dashboard/app_training.py +615 -0
  84. mcli/ml/dashboard/cli.py +51 -0
  85. mcli/ml/data_ingestion/api_connectors.py +501 -0
  86. mcli/ml/data_ingestion/data_pipeline.py +567 -0
  87. mcli/ml/data_ingestion/stream_processor.py +512 -0
  88. mcli/ml/database/migrations/env.py +94 -0
  89. mcli/ml/database/models.py +667 -0
  90. mcli/ml/database/session.py +200 -0
  91. mcli/ml/experimentation/ab_testing.py +845 -0
  92. mcli/ml/features/ensemble_features.py +607 -0
  93. mcli/ml/features/political_features.py +676 -0
  94. mcli/ml/features/recommendation_engine.py +809 -0
  95. mcli/ml/features/stock_features.py +573 -0
  96. mcli/ml/features/test_feature_engineering.py +346 -0
  97. mcli/ml/logging.py +85 -0
  98. mcli/ml/mlops/data_versioning.py +518 -0
  99. mcli/ml/mlops/experiment_tracker.py +377 -0
  100. mcli/ml/mlops/model_serving.py +481 -0
  101. mcli/ml/mlops/pipeline_orchestrator.py +614 -0
  102. mcli/ml/models/base_models.py +324 -0
  103. mcli/ml/models/ensemble_models.py +675 -0
  104. mcli/ml/models/recommendation_models.py +474 -0
  105. mcli/ml/models/test_models.py +487 -0
  106. mcli/ml/monitoring/drift_detection.py +676 -0
  107. mcli/ml/monitoring/metrics.py +45 -0
  108. mcli/ml/optimization/portfolio_optimizer.py +834 -0
  109. mcli/ml/preprocessing/data_cleaners.py +451 -0
  110. mcli/ml/preprocessing/feature_extractors.py +491 -0
  111. mcli/ml/preprocessing/ml_pipeline.py +382 -0
  112. mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
  113. mcli/ml/preprocessing/test_preprocessing.py +294 -0
  114. mcli/ml/scripts/populate_sample_data.py +200 -0
  115. mcli/ml/tasks.py +400 -0
  116. mcli/ml/tests/test_integration.py +429 -0
  117. mcli/ml/tests/test_training_dashboard.py +387 -0
  118. mcli/public/oi/oi.py +15 -0
  119. mcli/public/public.py +4 -0
  120. mcli/self/self_cmd.py +1246 -0
  121. mcli/workflow/daemon/api_daemon.py +800 -0
  122. mcli/workflow/daemon/async_command_database.py +681 -0
  123. mcli/workflow/daemon/async_process_manager.py +591 -0
  124. mcli/workflow/daemon/client.py +530 -0
  125. mcli/workflow/daemon/commands.py +1196 -0
  126. mcli/workflow/daemon/daemon.py +905 -0
  127. mcli/workflow/daemon/daemon_api.py +59 -0
  128. mcli/workflow/daemon/enhanced_daemon.py +571 -0
  129. mcli/workflow/daemon/process_cli.py +244 -0
  130. mcli/workflow/daemon/process_manager.py +439 -0
  131. mcli/workflow/daemon/test_daemon.py +275 -0
  132. mcli/workflow/dashboard/dashboard_cmd.py +113 -0
  133. mcli/workflow/docker/docker.py +0 -0
  134. mcli/workflow/file/file.py +100 -0
  135. mcli/workflow/gcloud/config.toml +21 -0
  136. mcli/workflow/gcloud/gcloud.py +58 -0
  137. mcli/workflow/git_commit/ai_service.py +328 -0
  138. mcli/workflow/git_commit/commands.py +430 -0
  139. mcli/workflow/lsh_integration.py +355 -0
  140. mcli/workflow/model_service/client.py +594 -0
  141. mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
  142. mcli/workflow/model_service/lightweight_embedder.py +397 -0
  143. mcli/workflow/model_service/lightweight_model_server.py +714 -0
  144. mcli/workflow/model_service/lightweight_test.py +241 -0
  145. mcli/workflow/model_service/model_service.py +1955 -0
  146. mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
  147. mcli/workflow/model_service/pdf_processor.py +386 -0
  148. mcli/workflow/model_service/test_efficient_runner.py +234 -0
  149. mcli/workflow/model_service/test_example.py +315 -0
  150. mcli/workflow/model_service/test_integration.py +131 -0
  151. mcli/workflow/model_service/test_new_features.py +149 -0
  152. mcli/workflow/openai/openai.py +99 -0
  153. mcli/workflow/politician_trading/commands.py +1790 -0
  154. mcli/workflow/politician_trading/config.py +134 -0
  155. mcli/workflow/politician_trading/connectivity.py +490 -0
  156. mcli/workflow/politician_trading/data_sources.py +395 -0
  157. mcli/workflow/politician_trading/database.py +410 -0
  158. mcli/workflow/politician_trading/demo.py +248 -0
  159. mcli/workflow/politician_trading/models.py +165 -0
  160. mcli/workflow/politician_trading/monitoring.py +413 -0
  161. mcli/workflow/politician_trading/scrapers.py +966 -0
  162. mcli/workflow/politician_trading/scrapers_california.py +412 -0
  163. mcli/workflow/politician_trading/scrapers_eu.py +377 -0
  164. mcli/workflow/politician_trading/scrapers_uk.py +350 -0
  165. mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
  166. mcli/workflow/politician_trading/supabase_functions.py +354 -0
  167. mcli/workflow/politician_trading/workflow.py +852 -0
  168. mcli/workflow/registry/registry.py +180 -0
  169. mcli/workflow/repo/repo.py +223 -0
  170. mcli/workflow/scheduler/commands.py +493 -0
  171. mcli/workflow/scheduler/cron_parser.py +238 -0
  172. mcli/workflow/scheduler/job.py +182 -0
  173. mcli/workflow/scheduler/monitor.py +139 -0
  174. mcli/workflow/scheduler/persistence.py +324 -0
  175. mcli/workflow/scheduler/scheduler.py +679 -0
  176. mcli/workflow/sync/sync_cmd.py +437 -0
  177. mcli/workflow/sync/test_cmd.py +314 -0
  178. mcli/workflow/videos/videos.py +242 -0
  179. mcli/workflow/wakatime/wakatime.py +11 -0
  180. mcli/workflow/workflow.py +37 -0
  181. mcli_framework-7.0.0.dist-info/METADATA +479 -0
  182. mcli_framework-7.0.0.dist-info/RECORD +186 -0
  183. mcli_framework-7.0.0.dist-info/WHEEL +5 -0
  184. mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
  185. mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
  186. mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,487 @@
1
+ """Test script for ensemble models"""
2
+
3
+ import sys
4
+ import os
5
+
6
# Put the directory three levels above this file at the front of sys.path.
# NOTE(review): presumably that is the parent of the ``mcli`` package so that
# absolute ``mcli.*`` imports resolve when run as a script — confirm.
_three_levels_up = os.path.join(os.path.dirname(__file__), "../../..")
sys.path.insert(0, _three_levels_up)
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from datetime import datetime, timedelta
12
+ import logging
13
+
14
# Module-level logger for this script. basicConfig asks the root logger for
# INFO and above; it only takes effect if the root logger has no handlers yet.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
17
+
18
+
19
def generate_mock_features(
    n_samples: int = 500, n_features: int = 150, seed: int = 42
) -> pd.DataFrame:
    """Generate a reproducible mock feature matrix for model tests.

    Column names are built in four named groups: technical indicators,
    political features, ensemble features and market-regime features. They
    are then padded with ``extra_feature_<i>`` columns, or truncated, to
    exactly ``n_features`` columns.

    For every row a latent ``market_trend`` is drawn from N(0, 1). Each
    group's values are drawn around a group-specific multiple of it, so
    features within a row are correlated. Extra columns are independent
    N(0, 0.5) noise.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of columns in the result.
        seed: Seed passed to ``np.random.seed``. The global NumPy RNG is
            reseeded on every call, so repeated calls with the same arguments
            return identical data.

    Returns:
        A DataFrame of shape ``(n_samples, n_features)``.
    """
    np.random.seed(seed)

    technical_names = [
        f"{indicator}_{period}"
        for indicator in ["sma", "ema", "rsi", "macd", "bb", "volume", "volatility"]
        for period in [5, 10, 20, 50]
    ]
    political_names = [
        f"political_{pol_feature}_{agg}"
        for pol_feature in ["influence", "trading_freq", "committee_align", "seniority"]
        for agg in ["mean", "max", "std"]
    ]
    ensemble_names = [f"ensemble_feature_{i}" for i in range(50)]
    regime_names = [
        f"market_{regime}_{metric}"
        for regime in ["volatility", "trend", "volume"]
        for metric in ["regime", "strength", "persistence"]
    ]

    # Each entry is (column count, multiplier of market_trend used as the
    # mean, std). Counts come from the name lists so that every generated
    # value lands under the column that labels it.
    group_specs = [
        (len(technical_names), 0.3, 0.8),
        (len(political_names), 0.1, 0.5),
        (len(ensemble_names), 0.2, 0.6),
        (len(regime_names), 0.4, 0.7),
    ]

    base_names = technical_names + political_names + ensemble_names + regime_names
    n_base = len(base_names)
    n_extra = max(0, n_features - n_base)
    feature_names = (
        base_names + [f"extra_feature_{i}" for i in range(n_base, n_features)]
    )[:n_features]

    rows = []
    for _ in range(n_samples):
        market_trend = np.random.normal(0, 1)
        parts = [
            np.random.normal(market_trend * scale, std, size)
            for size, scale, std in group_specs
        ]
        if n_extra > 0:
            parts.append(np.random.normal(0, 0.5, n_extra))
        # Truncation only matters when n_features < n_base.
        rows.append(np.concatenate(parts)[:n_features])

    return pd.DataFrame(rows, columns=feature_names)
82
+
83
+
84
def generate_mock_targets(n_samples: int, seed: int = 42) -> tuple:
    """Generate reproducible mock targets for model tests.

    Args:
        n_samples: Number of targets of each kind to generate.
        seed: Seed passed to ``np.random.seed``. The global NumPy RNG is
            reseeded on every call.

    Returns:
        ``(binary_targets, returns, risk_labels)``, each of length
        ``n_samples``:

        - ``binary_targets``: 0/1 ints, 1 where a latent N(0, 1) draw is
          positive.
        - ``returns``: floats drawn from N(0.05, 0.15), independent of
          ``binary_targets``.
        - ``risk_labels``: ints in {0, 1, 2} (low/medium/high), drawn with
          probabilities 0.3/0.5/0.2.

    Note:
        With the same seed, ``binary_targets`` from a call with fewer samples
        is a prefix of the one from a call with more samples. Callers that
        need disjoint splits should slice a single call or use different
        seeds.
    """
    np.random.seed(seed)

    # Latent score that determines the binary label.
    market_performance = np.random.normal(0, 1, n_samples)
    binary_targets = (market_performance > 0).astype(int)

    # Continuous returns, drawn independently of the label: 5% mean, 15% std.
    returns = np.random.normal(0.05, 0.15, n_samples)

    # Risk labels: 0=low, 1=medium, 2=high.
    risk_labels = np.random.choice([0, 1, 2], n_samples, p=[0.3, 0.5, 0.2])

    return binary_targets, returns, risk_labels
101
+
102
+
103
def test_base_models():
    """Smoke-test ``MLPBaseModel`` and ``ResNetModel`` from ``base_models``.

    Runs a forward pass through both models on mock data. On the MLP it also
    exercises ``predict_proba``, ``predict`` and ``calculate_metrics``, and
    checks that both prediction helpers return one entry per input row.
    """
    logger.info("Testing base models...")

    from base_models import MLPBaseModel, ResNetModel

    n_samples, n_features = 100, 50
    X = generate_mock_features(n_samples, n_features)
    y, _, _ = generate_mock_targets(n_samples)
    X_tensor = torch.FloatTensor(X.values)

    # MLP: raw forward pass, then the prediction helpers.
    mlp_model = MLPBaseModel(input_dim=n_features, hidden_dims=[128, 64], dropout_rate=0.2)
    output = mlp_model(X_tensor)
    logger.info(f"MLP output shape: {output.shape}")

    probas = mlp_model.predict_proba(X)
    predictions = mlp_model.predict(X)
    logger.info(f"MLP predictions shape: {predictions.shape}")
    # Both helpers should yield one entry per input row.
    assert len(probas) == n_samples, f"predict_proba returned {len(probas)} rows"
    assert len(predictions) == n_samples, f"predict returned {len(predictions)} rows"

    # ResNet: forward pass only.
    resnet_model = ResNetModel(input_dim=n_features, hidden_dim=128, num_blocks=2)
    output = resnet_model(X_tensor)
    logger.info(f"ResNet output shape: {output.shape}")

    metrics = mlp_model.calculate_metrics(y, predictions)
    logger.info(f"Model metrics: {metrics}")

    logger.info("✅ Base models test passed")
137
+
138
+
139
def test_ensemble_models():
    """Smoke-test the ensemble component models and ``DeepEnsembleModel``.

    First, runs a forward pass through each standalone component (attention,
    transformer, LSTM, CNN) on 10 mock rows.

    Then builds a three-member weighted-average ensemble with feature
    subsampling and bootstrap samples enabled. On 20 rows it exercises the
    ensemble's forward pass, per-member predictions and
    ``predict_proba``/``predict``. Both prediction helpers must return one
    entry per input row.
    """
    logger.info("Testing ensemble models...")

    from ensemble_models import (
        DeepEnsembleModel,
        EnsembleConfig,
        ModelConfig,
        AttentionStockPredictor,
        TransformerStockModel,
        LSTMStockPredictor,
        CNNFeatureExtractor,
    )

    X = generate_mock_features(200, 100)
    input_dim = X.shape[1]

    # Standalone components, each fed a fresh tensor of the same 10 rows.
    logger.info("Testing individual ensemble components...")
    component_rows = X.values[:10]

    attention_model = AttentionStockPredictor(input_dim, hidden_dim=64, num_heads=4, num_layers=2)
    output = attention_model(torch.FloatTensor(component_rows))
    logger.info(f"Attention model output shape: {output.shape}")

    transformer_model = TransformerStockModel(input_dim, d_model=64, nhead=4, num_layers=2)
    output = transformer_model(torch.FloatTensor(component_rows))
    logger.info(f"Transformer model output shape: {output.shape}")

    lstm_model = LSTMStockPredictor(input_dim, hidden_dim=64, num_layers=2)
    output = lstm_model(torch.FloatTensor(component_rows))
    logger.info(f"LSTM model output shape: {output.shape}")

    cnn_model = CNNFeatureExtractor(input_dim, num_filters=32, filter_sizes=[3, 5])
    output = cnn_model(torch.FloatTensor(component_rows))
    logger.info(f"CNN model output shape: {output.shape}")

    # Training settings shared by every ensemble member.
    shared = {"learning_rate": 0.001, "weight_decay": 1e-4, "batch_size": 32, "epochs": 5}
    model_configs = [
        ModelConfig(model_type="attention", hidden_dims=[128], dropout_rate=0.2, **shared),
        ModelConfig(model_type="lstm", hidden_dims=[128], dropout_rate=0.2, **shared),
        ModelConfig(model_type="mlp", hidden_dims=[256, 128], dropout_rate=0.3, **shared),
    ]

    ensemble_config = EnsembleConfig(
        base_models=model_configs,
        ensemble_method="weighted_average",
        feature_subsampling=True,
        bootstrap_samples=True,
    )
    ensemble_model = DeepEnsembleModel(input_dim, ensemble_config)

    n_eval = 20
    X_eval = X.values[:n_eval]

    ensemble_output = ensemble_model(torch.FloatTensor(X_eval))
    logger.info(f"Ensemble output shape: {ensemble_output.shape}")

    individual_preds = ensemble_model.get_individual_predictions(X_eval)
    logger.info(f"Individual predictions: {len(individual_preds)} models")

    ensemble_probas = ensemble_model.predict_proba(X_eval)
    ensemble_preds = ensemble_model.predict(X_eval)
    logger.info(f"Ensemble predictions shape: {ensemble_preds.shape}")
    # Both helpers should yield one entry per input row.
    assert len(ensemble_probas) == n_eval, f"predict_proba returned {len(ensemble_probas)} rows"
    assert len(ensemble_preds) == n_eval, f"predict returned {len(ensemble_preds)} rows"

    logger.info("✅ Ensemble models test passed")
238
+
239
+
240
def test_recommendation_model():
    """Exercise ``StockRecommendationModel`` end to end on mock data.

    Steps:

    - Build a two-member (MLP + attention) weighted-average ensemble config
      and wrap it in a recommendation config.
    - Check that a forward pass returns the four expected output heads.
    - Validate the types and value ranges of the generated
      ``PortfolioRecommendation`` objects.
    """
    logger.info("Testing recommendation model...")

    from recommendation_models import (
        StockRecommendationModel,
        RecommendationConfig,
        PortfolioRecommendation,
        RecommendationTrainer,
    )
    from ensemble_models import EnsembleConfig, ModelConfig

    features = generate_mock_features(300, 120)
    _, _, _ = generate_mock_targets(300)

    n_inputs = features.shape[1]
    symbols = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]

    # A small ensemble keeps the test quick; members differ only in type/width.
    common = {
        "dropout_rate": 0.2,
        "learning_rate": 0.001,
        "weight_decay": 1e-4,
        "batch_size": 32,
        "epochs": 3,
    }
    member_configs = [
        ModelConfig(model_type="mlp", hidden_dims=[128, 64], **common),
        ModelConfig(model_type="attention", hidden_dims=[64], **common),
    ]

    rec_config = RecommendationConfig(
        ensemble_config=EnsembleConfig(
            base_models=member_configs, ensemble_method="weighted_average"
        ),
        risk_adjustment=True,
        confidence_threshold=0.4,  # Lower for testing
        max_positions=5,
    )

    recommender = StockRecommendationModel(n_inputs, rec_config)

    # Every output head must be present after a forward pass.
    head_outputs = recommender(torch.FloatTensor(features.values[:10]))
    for head in ("main_prediction", "risk_assessment", "expected_returns", "confidence"):
        assert head in head_outputs, f"Missing output: {head}"
        logger.info(f"{head} shape: {head_outputs[head].shape}")

    recs = recommender.generate_recommendations(features.values[:5], symbols, market_data=None)
    logger.info(f"Generated {len(recs)} recommendations")

    # Log everything first so a later validation failure still shows all recs.
    for item in recs:
        logger.info(f"Ticker: {item.ticker}")
        logger.info(f" Score: {item.recommendation_score:.3f}")
        logger.info(f" Confidence: {item.confidence:.3f}")
        logger.info(f" Risk: {item.risk_level}")
        logger.info(f" Position: {item.position_size:.3f}")
        logger.info(f" Reason: {item.recommendation_reason}")

    for item in recs:
        assert isinstance(item, PortfolioRecommendation)
        assert 0 <= item.recommendation_score <= 1
        assert 0 <= item.confidence <= 1
        assert item.risk_level in ("low", "medium", "high")
        assert 0 <= item.position_size <= 1
        assert isinstance(item.key_features, list)
        assert isinstance(item.warnings, list)

    logger.info("✅ Recommendation model test passed")
326
+
327
+
328
def test_model_training():
    """Smoke-test training of the ensemble and recommendation models.

    Trains a one-member ``DeepEnsembleModel`` via ``EnsembleTrainer`` and a
    ``StockRecommendationModel`` via ``RecommendationTrainer`` on mock data.
    It logs train/validation accuracy, then generates recommendations from
    the trained recommendation model.

    Training and validation rows are sliced from a single generated dataset,
    so the two splits are disjoint. The mock-data generators reseed the NumPy
    RNG to the same value on every call. Generating the splits with two
    separate calls would therefore make the validation features and binary
    labels an exact copy of the first training rows.
    """
    logger.info("Testing model training...")

    from recommendation_models import (
        StockRecommendationModel,
        RecommendationTrainer,
        RecommendationConfig,
    )
    from ensemble_models import DeepEnsembleModel, EnsembleConfig, ModelConfig, EnsembleTrainer

    n_train, n_val, n_features = 200, 50, 80
    n_total = n_train + n_val

    # One draw, then a split, keeps train and validation disjoint.
    X_all = generate_mock_features(n_total, n_features)
    X_train, X_val = X_all.iloc[:n_train], X_all.iloc[n_train:]

    y_all, returns_all, risk_all = generate_mock_targets(n_total)
    y_train, y_val = y_all[:n_train], y_all[n_train:]
    returns_train, returns_val = returns_all[:n_train], returns_all[n_train:]
    risk_train, risk_val = risk_all[:n_train], risk_all[n_train:]

    input_dim = X_train.shape[1]

    # A single small MLP keeps training fast.
    model_configs = [
        ModelConfig(
            model_type="mlp",
            hidden_dims=[64, 32],
            dropout_rate=0.2,
            learning_rate=0.001,
            weight_decay=1e-4,
            batch_size=32,
            epochs=2,
        )
    ]

    ensemble_config = EnsembleConfig(base_models=model_configs, ensemble_method="weighted_average")
    recommendation_config = RecommendationConfig(
        ensemble_config=ensemble_config, confidence_threshold=0.3
    )

    # Ensemble training.
    ensemble_model = DeepEnsembleModel(input_dim, ensemble_config)
    ensemble_trainer = EnsembleTrainer(ensemble_model, ensemble_config)

    logger.info("Training ensemble model...")
    ensemble_result = ensemble_trainer.train(X_train.values, y_train, X_val.values, y_val)

    logger.info("Ensemble training metrics:")
    logger.info(f" Train accuracy: {ensemble_result.train_metrics.accuracy:.3f}")
    logger.info(f" Val accuracy: {ensemble_result.val_metrics.accuracy:.3f}")

    # Recommendation model training.
    rec_model = StockRecommendationModel(input_dim, recommendation_config)
    rec_trainer = RecommendationTrainer(rec_model, recommendation_config)

    logger.info("Training recommendation model...")
    rec_result = rec_trainer.train(
        X_train.values,
        y_train,
        returns_train,
        risk_train,
        X_val.values,
        y_val,
        returns_val,
        risk_val,
        epochs=5,
        batch_size=32,
    )

    logger.info("Recommendation training metrics:")
    logger.info(f" Train accuracy: {rec_result.train_metrics.accuracy:.3f}")
    logger.info(f" Val accuracy: {rec_result.val_metrics.accuracy:.3f}")

    # The trained model should still be able to produce recommendations.
    test_recommendations = rec_model.generate_recommendations(
        X_val.values[:3], ["AAPL", "MSFT", "GOOGL"]
    )

    logger.info(f"Generated {len(test_recommendations)} test recommendations")

    logger.info("✅ Model training test passed")
410
+
411
+
412
def test_model_persistence():
    """Check that an ``MLPBaseModel`` round-trips through save/load.

    Saves a freshly built model and loads the file into a second model with
    the same architecture. It then asserts that both produce the same
    ``predict_proba`` output to 6 decimal places.

    The checkpoint is written inside a temporary directory. That directory is
    removed even if saving, loading or predicting raises, so a failing run
    does not leak files.
    """
    logger.info("Testing model persistence...")

    from base_models import MLPBaseModel
    import tempfile

    model = MLPBaseModel(input_dim=50, hidden_dims=[64, 32])
    X_test = generate_mock_features(10, 50)

    # Reference predictions from the in-memory model.
    original_preds = model.predict_proba(X_test)

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pt")
        model.save_model(model_path)

        # A fresh model with the same architecture, populated from disk.
        new_model = MLPBaseModel(input_dim=50, hidden_dims=[64, 32])
        new_model.load_model(model_path)
        loaded_preds = new_model.predict_proba(X_test)

    # Should be identical.
    np.testing.assert_array_almost_equal(original_preds, loaded_preds, decimal=6)

    logger.info("✅ Model persistence test passed")
446
+
447
+
448
def main() -> bool:
    """Run every model test in sequence and log a summary.

    The first test that raises stops the run. Its error is logged together
    with the traceback.

    Returns:
        True if all tests completed without raising, False otherwise.
    """
    logger.info("Starting ensemble model tests...")

    tests = (
        test_base_models,
        test_ensemble_models,
        test_recommendation_model,
        test_model_training,
        test_model_persistence,
    )

    try:
        for run_test in tests:
            run_test()
    except Exception as e:
        # Top-level boundary: record the failure (with traceback) and report
        # it through the return value / exit code instead of propagating.
        logger.exception("❌ Ensemble model tests failed: %s", e)
        return False

    logger.info("🎉 All ensemble model tests passed!")

    # Print summary
    logger.info("\n" + "=" * 60)
    logger.info("PYTORCH ENSEMBLE MODEL SYSTEM SUMMARY")
    logger.info("=" * 60)
    logger.info("✅ Base models: MLP, ResNet with proper abstractions")
    logger.info("✅ Ensemble models: Attention, Transformer, LSTM, CNN")
    logger.info("✅ Deep ensemble: Weighted averaging, voting, stacking")
    logger.info("✅ Recommendation system: Portfolio optimization")
    logger.info("✅ Training pipeline: Multi-task learning")
    logger.info("✅ Model persistence: Save/load functionality")
    logger.info("✅ Comprehensive metrics: Classification, regression, trading")
    logger.info("=" * 60)

    return True


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)