mcli-framework 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcli-framework might be problematic. Click here for more details.

Files changed (186)
  1. mcli/app/chat_cmd.py +42 -0
  2. mcli/app/commands_cmd.py +226 -0
  3. mcli/app/completion_cmd.py +216 -0
  4. mcli/app/completion_helpers.py +288 -0
  5. mcli/app/cron_test_cmd.py +697 -0
  6. mcli/app/logs_cmd.py +419 -0
  7. mcli/app/main.py +492 -0
  8. mcli/app/model/model.py +1060 -0
  9. mcli/app/model_cmd.py +227 -0
  10. mcli/app/redis_cmd.py +269 -0
  11. mcli/app/video/video.py +1114 -0
  12. mcli/app/visual_cmd.py +303 -0
  13. mcli/chat/chat.py +2409 -0
  14. mcli/chat/command_rag.py +514 -0
  15. mcli/chat/enhanced_chat.py +652 -0
  16. mcli/chat/system_controller.py +1010 -0
  17. mcli/chat/system_integration.py +1016 -0
  18. mcli/cli.py +25 -0
  19. mcli/config.toml +20 -0
  20. mcli/lib/api/api.py +586 -0
  21. mcli/lib/api/daemon_client.py +203 -0
  22. mcli/lib/api/daemon_client_local.py +44 -0
  23. mcli/lib/api/daemon_decorator.py +217 -0
  24. mcli/lib/api/mcli_decorators.py +1032 -0
  25. mcli/lib/auth/auth.py +85 -0
  26. mcli/lib/auth/aws_manager.py +85 -0
  27. mcli/lib/auth/azure_manager.py +91 -0
  28. mcli/lib/auth/credential_manager.py +192 -0
  29. mcli/lib/auth/gcp_manager.py +93 -0
  30. mcli/lib/auth/key_manager.py +117 -0
  31. mcli/lib/auth/mcli_manager.py +93 -0
  32. mcli/lib/auth/token_manager.py +75 -0
  33. mcli/lib/auth/token_util.py +1011 -0
  34. mcli/lib/config/config.py +47 -0
  35. mcli/lib/discovery/__init__.py +1 -0
  36. mcli/lib/discovery/command_discovery.py +274 -0
  37. mcli/lib/erd/erd.py +1345 -0
  38. mcli/lib/erd/generate_graph.py +453 -0
  39. mcli/lib/files/files.py +76 -0
  40. mcli/lib/fs/fs.py +109 -0
  41. mcli/lib/lib.py +29 -0
  42. mcli/lib/logger/logger.py +611 -0
  43. mcli/lib/performance/optimizer.py +409 -0
  44. mcli/lib/performance/rust_bridge.py +502 -0
  45. mcli/lib/performance/uvloop_config.py +154 -0
  46. mcli/lib/pickles/pickles.py +50 -0
  47. mcli/lib/search/cached_vectorizer.py +479 -0
  48. mcli/lib/services/data_pipeline.py +460 -0
  49. mcli/lib/services/lsh_client.py +441 -0
  50. mcli/lib/services/redis_service.py +387 -0
  51. mcli/lib/shell/shell.py +137 -0
  52. mcli/lib/toml/toml.py +33 -0
  53. mcli/lib/ui/styling.py +47 -0
  54. mcli/lib/ui/visual_effects.py +634 -0
  55. mcli/lib/watcher/watcher.py +185 -0
  56. mcli/ml/api/app.py +215 -0
  57. mcli/ml/api/middleware.py +224 -0
  58. mcli/ml/api/routers/admin_router.py +12 -0
  59. mcli/ml/api/routers/auth_router.py +244 -0
  60. mcli/ml/api/routers/backtest_router.py +12 -0
  61. mcli/ml/api/routers/data_router.py +12 -0
  62. mcli/ml/api/routers/model_router.py +302 -0
  63. mcli/ml/api/routers/monitoring_router.py +12 -0
  64. mcli/ml/api/routers/portfolio_router.py +12 -0
  65. mcli/ml/api/routers/prediction_router.py +267 -0
  66. mcli/ml/api/routers/trade_router.py +12 -0
  67. mcli/ml/api/routers/websocket_router.py +76 -0
  68. mcli/ml/api/schemas.py +64 -0
  69. mcli/ml/auth/auth_manager.py +425 -0
  70. mcli/ml/auth/models.py +154 -0
  71. mcli/ml/auth/permissions.py +302 -0
  72. mcli/ml/backtesting/backtest_engine.py +502 -0
  73. mcli/ml/backtesting/performance_metrics.py +393 -0
  74. mcli/ml/cache.py +400 -0
  75. mcli/ml/cli/main.py +398 -0
  76. mcli/ml/config/settings.py +394 -0
  77. mcli/ml/configs/dvc_config.py +230 -0
  78. mcli/ml/configs/mlflow_config.py +131 -0
  79. mcli/ml/configs/mlops_manager.py +293 -0
  80. mcli/ml/dashboard/app.py +532 -0
  81. mcli/ml/dashboard/app_integrated.py +738 -0
  82. mcli/ml/dashboard/app_supabase.py +560 -0
  83. mcli/ml/dashboard/app_training.py +615 -0
  84. mcli/ml/dashboard/cli.py +51 -0
  85. mcli/ml/data_ingestion/api_connectors.py +501 -0
  86. mcli/ml/data_ingestion/data_pipeline.py +567 -0
  87. mcli/ml/data_ingestion/stream_processor.py +512 -0
  88. mcli/ml/database/migrations/env.py +94 -0
  89. mcli/ml/database/models.py +667 -0
  90. mcli/ml/database/session.py +200 -0
  91. mcli/ml/experimentation/ab_testing.py +845 -0
  92. mcli/ml/features/ensemble_features.py +607 -0
  93. mcli/ml/features/political_features.py +676 -0
  94. mcli/ml/features/recommendation_engine.py +809 -0
  95. mcli/ml/features/stock_features.py +573 -0
  96. mcli/ml/features/test_feature_engineering.py +346 -0
  97. mcli/ml/logging.py +85 -0
  98. mcli/ml/mlops/data_versioning.py +518 -0
  99. mcli/ml/mlops/experiment_tracker.py +377 -0
  100. mcli/ml/mlops/model_serving.py +481 -0
  101. mcli/ml/mlops/pipeline_orchestrator.py +614 -0
  102. mcli/ml/models/base_models.py +324 -0
  103. mcli/ml/models/ensemble_models.py +675 -0
  104. mcli/ml/models/recommendation_models.py +474 -0
  105. mcli/ml/models/test_models.py +487 -0
  106. mcli/ml/monitoring/drift_detection.py +676 -0
  107. mcli/ml/monitoring/metrics.py +45 -0
  108. mcli/ml/optimization/portfolio_optimizer.py +834 -0
  109. mcli/ml/preprocessing/data_cleaners.py +451 -0
  110. mcli/ml/preprocessing/feature_extractors.py +491 -0
  111. mcli/ml/preprocessing/ml_pipeline.py +382 -0
  112. mcli/ml/preprocessing/politician_trading_preprocessor.py +569 -0
  113. mcli/ml/preprocessing/test_preprocessing.py +294 -0
  114. mcli/ml/scripts/populate_sample_data.py +200 -0
  115. mcli/ml/tasks.py +400 -0
  116. mcli/ml/tests/test_integration.py +429 -0
  117. mcli/ml/tests/test_training_dashboard.py +387 -0
  118. mcli/public/oi/oi.py +15 -0
  119. mcli/public/public.py +4 -0
  120. mcli/self/self_cmd.py +1246 -0
  121. mcli/workflow/daemon/api_daemon.py +800 -0
  122. mcli/workflow/daemon/async_command_database.py +681 -0
  123. mcli/workflow/daemon/async_process_manager.py +591 -0
  124. mcli/workflow/daemon/client.py +530 -0
  125. mcli/workflow/daemon/commands.py +1196 -0
  126. mcli/workflow/daemon/daemon.py +905 -0
  127. mcli/workflow/daemon/daemon_api.py +59 -0
  128. mcli/workflow/daemon/enhanced_daemon.py +571 -0
  129. mcli/workflow/daemon/process_cli.py +244 -0
  130. mcli/workflow/daemon/process_manager.py +439 -0
  131. mcli/workflow/daemon/test_daemon.py +275 -0
  132. mcli/workflow/dashboard/dashboard_cmd.py +113 -0
  133. mcli/workflow/docker/docker.py +0 -0
  134. mcli/workflow/file/file.py +100 -0
  135. mcli/workflow/gcloud/config.toml +21 -0
  136. mcli/workflow/gcloud/gcloud.py +58 -0
  137. mcli/workflow/git_commit/ai_service.py +328 -0
  138. mcli/workflow/git_commit/commands.py +430 -0
  139. mcli/workflow/lsh_integration.py +355 -0
  140. mcli/workflow/model_service/client.py +594 -0
  141. mcli/workflow/model_service/download_and_run_efficient_models.py +288 -0
  142. mcli/workflow/model_service/lightweight_embedder.py +397 -0
  143. mcli/workflow/model_service/lightweight_model_server.py +714 -0
  144. mcli/workflow/model_service/lightweight_test.py +241 -0
  145. mcli/workflow/model_service/model_service.py +1955 -0
  146. mcli/workflow/model_service/ollama_efficient_runner.py +425 -0
  147. mcli/workflow/model_service/pdf_processor.py +386 -0
  148. mcli/workflow/model_service/test_efficient_runner.py +234 -0
  149. mcli/workflow/model_service/test_example.py +315 -0
  150. mcli/workflow/model_service/test_integration.py +131 -0
  151. mcli/workflow/model_service/test_new_features.py +149 -0
  152. mcli/workflow/openai/openai.py +99 -0
  153. mcli/workflow/politician_trading/commands.py +1790 -0
  154. mcli/workflow/politician_trading/config.py +134 -0
  155. mcli/workflow/politician_trading/connectivity.py +490 -0
  156. mcli/workflow/politician_trading/data_sources.py +395 -0
  157. mcli/workflow/politician_trading/database.py +410 -0
  158. mcli/workflow/politician_trading/demo.py +248 -0
  159. mcli/workflow/politician_trading/models.py +165 -0
  160. mcli/workflow/politician_trading/monitoring.py +413 -0
  161. mcli/workflow/politician_trading/scrapers.py +966 -0
  162. mcli/workflow/politician_trading/scrapers_california.py +412 -0
  163. mcli/workflow/politician_trading/scrapers_eu.py +377 -0
  164. mcli/workflow/politician_trading/scrapers_uk.py +350 -0
  165. mcli/workflow/politician_trading/scrapers_us_states.py +438 -0
  166. mcli/workflow/politician_trading/supabase_functions.py +354 -0
  167. mcli/workflow/politician_trading/workflow.py +852 -0
  168. mcli/workflow/registry/registry.py +180 -0
  169. mcli/workflow/repo/repo.py +223 -0
  170. mcli/workflow/scheduler/commands.py +493 -0
  171. mcli/workflow/scheduler/cron_parser.py +238 -0
  172. mcli/workflow/scheduler/job.py +182 -0
  173. mcli/workflow/scheduler/monitor.py +139 -0
  174. mcli/workflow/scheduler/persistence.py +324 -0
  175. mcli/workflow/scheduler/scheduler.py +679 -0
  176. mcli/workflow/sync/sync_cmd.py +437 -0
  177. mcli/workflow/sync/test_cmd.py +314 -0
  178. mcli/workflow/videos/videos.py +242 -0
  179. mcli/workflow/wakatime/wakatime.py +11 -0
  180. mcli/workflow/workflow.py +37 -0
  181. mcli_framework-7.0.0.dist-info/METADATA +479 -0
  182. mcli_framework-7.0.0.dist-info/RECORD +186 -0
  183. mcli_framework-7.0.0.dist-info/WHEEL +5 -0
  184. mcli_framework-7.0.0.dist-info/entry_points.txt +7 -0
  185. mcli_framework-7.0.0.dist-info/licenses/LICENSE +21 -0
  186. mcli_framework-7.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,474 @@
1
+ """Stock recommendation models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import pandas as pd
8
+ from typing import Dict, List, Optional, Tuple, Any, Union
9
+ from dataclasses import dataclass
10
+ import logging
11
+ from datetime import datetime
12
+ from base_models import BaseStockModel, ModelMetrics, ValidationResult
13
+ from ensemble_models import DeepEnsembleModel, EnsembleConfig, ModelConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@dataclass
class RecommendationConfig:
    """Configuration for ``StockRecommendationModel``.

    Within this module only ``ensemble_config``, ``confidence_threshold`` and
    ``max_positions`` are read by model code. The whole instance is also
    forwarded to the base class as ``config.__dict__``. The other fields are
    carried along but not consulted here.
    """

    ensemble_config: EnsembleConfig  # settings for the main DeepEnsembleModel
    risk_adjustment: bool = True  # NOTE(review): not read anywhere in this module
    confidence_threshold: float = 0.6  # recommendations below this confidence are dropped
    diversification_penalty: float = 0.1  # NOTE(review): not read anywhere in this module
    sector_weights: Optional[Dict[str, float]] = None  # NOTE(review): not read in this module
    max_positions: int = 20  # upper bound on the number of recommendations returned
    rebalance_frequency: str = "weekly"  # daily, weekly, monthly
29
+
30
+
31
@dataclass
class PortfolioRecommendation:
    """One per-ticker recommendation produced by the recommendation model.

    ``key_features``, ``warnings`` and ``timestamp`` default to ``None``.
    ``__post_init__`` then replaces them with a fresh empty list (one per
    instance, never shared) or the local time at construction. Values passed
    explicitly are kept unchanged.
    """

    ticker: str
    recommendation_score: float
    confidence: float
    risk_level: str  # the model emits "low", "medium" or "high"
    expected_return: float
    risk_adjusted_score: float
    position_size: float  # fraction of the portfolio to allocate
    entry_price: Optional[float] = None
    target_price: Optional[float] = None
    stop_loss: Optional[float] = None
    recommendation_reason: str = ""
    key_features: Optional[List[str]] = None
    warnings: Optional[List[str]] = None
    timestamp: Optional[datetime] = None

    def __post_init__(self):
        # Swap the None placeholders for real values; explicit arguments win.
        self.key_features = [] if self.key_features is None else self.key_features
        self.warnings = [] if self.warnings is None else self.warnings
        self.timestamp = datetime.now() if self.timestamp is None else self.timestamp
57
+
58
+
59
class StockRecommendationModel(BaseStockModel):
    """Stock recommendation model combining ensemble prediction with portfolio construction.

    Four heads read the same feature vector:

    * ``ensemble_model``: main classifier (a ``DeepEnsembleModel``);
    * ``risk_network``: 3-way risk classifier (low / medium / high);
    * ``return_network``: scalar expected-return regressor;
    * ``confidence_network``: sigmoid confidence score, fed with the features
      plus two summary statistics derived from the other heads.
    """

    # Column i of the risk head's output corresponds to _RISK_LEVELS[i].
    _RISK_LEVELS = ("low", "medium", "high")
    # Multiplicative penalty on the recommendation score, per risk level.
    _RISK_PENALTY = {"low": 0.0, "medium": 0.1, "high": 0.2}
    # Position-size scaling per risk level.
    _RISK_POSITION_MULTIPLIER = {"low": 1.0, "medium": 0.8, "high": 0.6}

    def __init__(self, input_dim: int, config: RecommendationConfig):
        """Build all prediction heads.

        Args:
            input_dim: Number of input features per sample.
            config: Recommendation settings. ``config.ensemble_config``
                configures the main ensemble classifier.
        """
        super().__init__(input_dim, config.__dict__)
        self.recommendation_config = config

        # Core ensemble model
        self.ensemble_model = DeepEnsembleModel(input_dim, config.ensemble_config)

        # Risk assessment network
        self.risk_network = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 3),  # low, medium, high risk
        )

        # Expected return regression network
        self.return_network = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),  # Expected return
        )

        # Confidence estimation network
        self.confidence_network = nn.Sequential(
            nn.Linear(input_dim + 2, 128),  # +2: max class prob and risk entropy
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),  # Confidence between 0 and 1
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run every head on a batch of feature rows.

        Args:
            x: Feature tensor of shape ``(batch, input_dim)``.

        Returns:
            Dict with
            ``main_prediction`` (raw ensemble output, treated as logits here),
            ``risk_assessment`` (softmax probabilities over the 3 risk classes),
            ``expected_returns`` (``(batch, 1)``) and
            ``confidence`` (``(batch, 1)``, values in (0, 1)).
        """
        main_prediction = self.ensemble_model(x)

        risk_logits = self.risk_network(x)
        risk_probs = self.softmax(risk_logits)

        expected_returns = self.return_network(x)

        # The confidence head sees the features plus two scalars: the
        # ensemble's top-class probability and the entropy of the risk
        # distribution (a proxy for how unsure the risk head is).
        main_probs = self.softmax(main_prediction)
        max_prob = torch.max(main_probs, dim=1, keepdim=True)[0]
        risk_entropy = -torch.sum(risk_probs * torch.log(risk_probs + 1e-8), dim=1, keepdim=True)

        confidence_input = torch.cat([x, max_prob, risk_entropy], dim=1)
        confidence = self.confidence_network(confidence_input)

        return {
            "main_prediction": main_prediction,
            "risk_assessment": risk_probs,
            "expected_returns": expected_returns,
            "confidence": confidence,
        }

    def predict_proba(self, X: Union[torch.Tensor, np.ndarray, pd.DataFrame]) -> np.ndarray:
        """Return softmax class probabilities of the main ensemble head as a NumPy array."""
        self.eval()
        with torch.no_grad():
            X_tensor = self.preprocess_input(X)
            outputs = self.forward(X_tensor)
            probas = self.softmax(outputs["main_prediction"])
        return probas.cpu().numpy()

    def generate_recommendations(
        self,
        X: Union[torch.Tensor, np.ndarray, pd.DataFrame],
        tickers: List[str],
        market_data: Optional[pd.DataFrame] = None,
    ) -> List[PortfolioRecommendation]:
        """Score every ticker and return the portfolio-optimised recommendation list.

        Args:
            X: Features, one row per ticker, in the same order as ``tickers``.
            tickers: Ticker symbols matching the rows of ``X``.
            market_data: Optional frame with ``"symbol"`` and ``"close"``
                columns. When a ticker is present, its last matching row
                supplies the current price used for entry, target and stop levels.

        Returns:
            Recommendations as filtered, ranked and sized by ``_optimize_portfolio``.

        Raises:
            ValueError: If ``len(tickers)`` differs from the number of rows in ``X``.
        """
        self.eval()

        with torch.no_grad():
            X_tensor = self.preprocess_input(X)
            n_rows = X_tensor.shape[0]
            if len(tickers) != n_rows:
                raise ValueError(
                    f"Got {len(tickers)} tickers for {n_rows} input rows; "
                    "tickers and rows of X must match one-to-one"
                )
            outputs = self.forward(X_tensor)

            main_probs = self.softmax(outputs["main_prediction"]).cpu().numpy()
            risk_probs = outputs["risk_assessment"].cpu().numpy()
            expected_returns = outputs["expected_returns"].cpu().numpy().flatten()
            confidences = outputs["confidence"].cpu().numpy().flatten()
            features = X_tensor.cpu().numpy()

        recommendations = [
            self._create_recommendation(
                ticker,
                main_probs[i],
                risk_probs[i],
                expected_returns[i],
                confidences[i],
                features[i],
                market_data,
            )
            for i, ticker in enumerate(tickers)
        ]

        return self._optimize_portfolio(recommendations)

    def _create_recommendation(
        self,
        ticker: str,
        main_prob: np.ndarray,
        risk_prob: np.ndarray,
        expected_return: float,
        confidence: float,
        features: np.ndarray,
        market_data: Optional[pd.DataFrame],
    ) -> PortfolioRecommendation:
        """Turn one ticker's head outputs into a ``PortfolioRecommendation``."""
        # Probability of the positive outcome. NOTE(review): assumes the main
        # head is a classifier whose class 1 means "positive"; confirm.
        recommendation_score = main_prob[1]

        risk_level = self._RISK_LEVELS[int(np.argmax(risk_prob))]
        risk_penalty = self._RISK_PENALTY[risk_level]
        risk_adjusted_score = recommendation_score * (1 - risk_penalty)

        # Position sizing: 5% base, scaled by confidence (x2) and risk, capped at 15%.
        base_position = 0.05
        position_size = (
            base_position * (confidence * 2) * self._RISK_POSITION_MULTIPLIER[risk_level]
        )
        position_size = min(position_size, 0.15)

        # Price targets (simplified; a real system would use dedicated models).
        entry_price = None
        target_price = None
        stop_loss = None

        if market_data is not None and ticker in market_data["symbol"].values:
            ticker_data = market_data[market_data["symbol"] == ticker].iloc[-1]
            current_price = ticker_data["close"]

            entry_price = current_price
            target_price = current_price * (1 + expected_return * 0.5)  # Conservative target
            # Riskier names get a wider stop: 10%, 11% or 12% below the price.
            stop_loss = current_price * (1 - 0.1 * (1 + risk_penalty))

        reason = self._generate_recommendation_reason(
            recommendation_score, risk_level, confidence, expected_return
        )
        key_features = self._extract_key_features(features)
        warnings = self._generate_warnings(risk_level, confidence, recommendation_score)

        return PortfolioRecommendation(
            ticker=ticker,
            recommendation_score=recommendation_score,
            confidence=confidence,
            risk_level=risk_level,
            expected_return=expected_return,
            risk_adjusted_score=risk_adjusted_score,
            position_size=position_size,
            entry_price=entry_price,
            target_price=target_price,
            stop_loss=stop_loss,
            recommendation_reason=reason,
            key_features=key_features,
            warnings=warnings,
        )

    def _generate_recommendation_reason(
        self, score: float, risk: str, confidence: float, expected_return: float
    ) -> str:
        """Build a one-line human-readable explanation of a recommendation."""
        if score > 0.7:
            strength = "Strong"
        elif score > 0.6:
            strength = "Moderate"
        else:
            strength = "Weak"

        return (
            f"{strength} recommendation based on {confidence:.1%} confidence. "
            f"Expected return: {expected_return:.1%}, Risk level: {risk}."
        )

    def _extract_key_features(self, features: np.ndarray) -> List[str]:
        """Return feature-group names driving the recommendation.

        Placeholder: ``features`` is currently ignored and a fixed list is returned.
        """
        return ["technical_indicators", "political_influence", "market_regime"]

    def _generate_warnings(self, risk_level: str, confidence: float, score: float) -> List[str]:
        """Collect warnings for low confidence, high risk or a weak signal."""
        warnings = []

        if confidence < 0.5:
            warnings.append("Low confidence prediction")

        if risk_level == "high":
            warnings.append("High risk investment")

        if score < 0.55:
            warnings.append("Weak recommendation signal")

        return warnings

    def _optimize_portfolio(
        self, recommendations: List[PortfolioRecommendation]
    ) -> List[PortfolioRecommendation]:
        """Select and size the final set of positions.

        Steps:

        1. Drop recommendations below ``confidence_threshold``.
        2. Rank the rest by ``risk_adjusted_score``, highest first.
        3. Keep at most ``max_positions``.
        4. Clip position sizes so the total allocation never exceeds 1.0.

        Filtering happens before counting, so rejected candidates no longer use
        up position slots or allocation budget. ``position_size`` of the
        selected recommendations may be reduced in place.
        """
        config = self.recommendation_config

        eligible = [r for r in recommendations if r.confidence >= config.confidence_threshold]
        # sorted() is stable, so ties keep their input order.
        ranked = sorted(eligible, key=lambda r: r.risk_adjusted_score, reverse=True)
        candidates = ranked[: max(config.max_positions, 0)]

        selected = []
        allocated = 0.0
        for rec in candidates:
            remaining = 1.0 - allocated
            if remaining <= 0:
                break
            if rec.position_size > remaining:
                rec.position_size = remaining
            allocated += rec.position_size
            selected.append(rec)

        return selected
317
+
318
+
319
class RecommendationTrainer:
    """Jointly trains all heads of a ``StockRecommendationModel``.

    The training objective is a weighted sum of four terms:

    * main classification cross-entropy;
    * 0.5 x risk-class NLL;
    * 0.3 x expected-return MSE;
    * 0.2 x confidence BCE.

    Validation uses the same objective, so train and validation losses are comparable.
    """

    def __init__(self, model: StockRecommendationModel, config: RecommendationConfig):
        self.model = model
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def train(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        returns_train: np.ndarray,
        risk_labels_train: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None,
        returns_val: Optional[np.ndarray] = None,
        risk_labels_val: Optional[np.ndarray] = None,
        epochs: int = 100,
        batch_size: int = 64,
    ) -> ValidationResult:
        """Train the recommendation model.

        Args:
            X_train: Feature rows.
            y_train: Integer class labels for the main head.
            returns_train: Realised returns, one per row (``(N,)`` or ``(N, 1)``).
            risk_labels_train: Integer risk classes (0=low, 1=medium, 2=high).
            X_val, y_val, returns_val, risk_labels_val: Optional validation
                set. If ``X_val`` is given, the other three are required.
            epochs: Number of passes over the training data.
            batch_size: Mini-batch size.

        Returns:
            ``ValidationResult`` with final metrics and per-epoch loss history.

        Raises:
            ValueError: If the training set is empty, or ``X_val`` is given
                without all of ``y_val``, ``returns_val`` and ``risk_labels_val``.
        """
        from torch.utils.data import DataLoader, TensorDataset

        if len(X_train) == 0:
            raise ValueError("X_train is empty; nothing to train on")
        has_val = X_val is not None
        if has_val and any(a is None for a in (y_val, returns_val, risk_labels_val)):
            raise ValueError(
                "When X_val is given, y_val, returns_val and risk_labels_val are required too"
            )

        logger.info("Training recommendation model...")

        X_tensor = torch.FloatTensor(X_train).to(self.device)
        y_tensor = torch.LongTensor(y_train).to(self.device)
        returns_tensor = torch.FloatTensor(returns_train).to(self.device)
        risk_tensor = torch.LongTensor(risk_labels_train).to(self.device)

        dataset = TensorDataset(X_tensor, y_tensor, returns_tensor, risk_tensor)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Separate optimizers so the ensemble and the auxiliary heads can use
        # different learning rates.
        ensemble_params = list(self.model.ensemble_model.parameters())
        other_params = (
            list(self.model.risk_network.parameters())
            + list(self.model.return_network.parameters())
            + list(self.model.confidence_network.parameters())
        )

        ensemble_optimizer = torch.optim.Adam(ensemble_params, lr=0.001, weight_decay=1e-4)
        other_optimizer = torch.optim.Adam(other_params, lr=0.0005, weight_decay=1e-4)

        train_losses: List[float] = []
        val_losses: List[float] = []

        for epoch in range(epochs):
            self.model.train()
            epoch_loss = 0.0

            for batch_X, batch_y, batch_returns, batch_risk in loader:
                ensemble_optimizer.zero_grad()
                other_optimizer.zero_grad()

                outputs = self.model(batch_X)
                total_loss = self._compute_loss(outputs, batch_y, batch_returns, batch_risk)

                total_loss.backward()
                ensemble_optimizer.step()
                other_optimizer.step()

                epoch_loss += total_loss.item()

            avg_loss = epoch_loss / len(loader)
            train_losses.append(avg_loss)

            val_str = ""
            if has_val:
                val_loss = self._validate(X_val, y_val, returns_val, risk_labels_val)
                val_losses.append(val_loss)
                val_str = f", Val Loss: {val_loss:.4f}"

            if epoch % 10 == 0:
                logger.info("Epoch %d/%d, Train Loss: %.4f%s", epoch, epochs, avg_loss, val_str)

        train_metrics = self._evaluate(X_train, y_train)
        val_metrics = self._evaluate(X_val, y_val) if has_val else None

        self.model.is_trained = True

        return ValidationResult(
            train_metrics=train_metrics,
            val_metrics=val_metrics,
            training_history={"train_losses": train_losses, "val_losses": val_losses},
        )

    def _compute_loss(
        self,
        outputs: Dict[str, torch.Tensor],
        y: torch.Tensor,
        returns: torch.Tensor,
        risk: torch.Tensor,
    ) -> torch.Tensor:
        """Return the weighted multi-task loss shared by training and validation."""
        main_logits = outputs["main_prediction"]
        main_loss = F.cross_entropy(main_logits, y)

        # The model's forward() already applies softmax to the risk head, so
        # take NLL of the log-probabilities instead of cross-entropy, which
        # would apply softmax a second time.
        risk_loss = F.nll_loss(torch.log(outputs["risk_assessment"] + 1e-8), risk)

        # Flatten both sides. squeeze() collapses a batch of size 1 to 0-d,
        # and (B,1) targets would otherwise broadcast against (B,) to (B, B).
        return_loss = F.mse_loss(outputs["expected_returns"].reshape(-1), returns.reshape(-1))

        # Confidence target is 1 where the true class receives > 0.5 probability.
        main_probs = torch.softmax(main_logits, dim=1)
        correct_probs = main_probs.gather(1, y.unsqueeze(1)).reshape(-1)
        target_confidence = (correct_probs > 0.5).float()
        conf_loss = F.binary_cross_entropy(outputs["confidence"].reshape(-1), target_confidence)

        return main_loss + 0.5 * risk_loss + 0.3 * return_loss + 0.2 * conf_loss

    def _validate(
        self,
        X_val: np.ndarray,
        y_val: np.ndarray,
        returns_val: np.ndarray,
        risk_labels_val: np.ndarray,
    ) -> float:
        """Return the full-batch validation loss, using the training objective."""
        self.model.eval()

        with torch.no_grad():
            X_tensor = torch.FloatTensor(X_val).to(self.device)
            y_tensor = torch.LongTensor(y_val).to(self.device)
            returns_tensor = torch.FloatTensor(returns_val).to(self.device)
            risk_tensor = torch.LongTensor(risk_labels_val).to(self.device)

            outputs = self.model(X_tensor)
            total_loss = self._compute_loss(outputs, y_tensor, returns_tensor, risk_tensor)

        return total_loss.item()

    def _evaluate(self, X: np.ndarray, y: np.ndarray) -> ModelMetrics:
        """Compute model metrics on ``(X, y)``; all-zero metrics when either is None."""
        if X is None or y is None:
            return ModelMetrics(0, 0, 0, 0, 0)

        predictions = self.model.predict(X)
        probabilities = self.model.predict_proba(X)

        return self.model.calculate_metrics(y, predictions, probabilities)