mcli-framework 7.6.0__py3-none-any.whl → 7.6.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcli-framework might be problematic; see the package's registry page for more details.

Files changed (49)
  1. mcli/app/commands_cmd.py +51 -39
  2. mcli/app/main.py +10 -2
  3. mcli/app/model_cmd.py +1 -1
  4. mcli/lib/custom_commands.py +4 -10
  5. mcli/ml/api/app.py +1 -5
  6. mcli/ml/dashboard/app.py +2 -2
  7. mcli/ml/dashboard/app_integrated.py +168 -116
  8. mcli/ml/dashboard/app_supabase.py +7 -3
  9. mcli/ml/dashboard/app_training.py +3 -6
  10. mcli/ml/dashboard/components/charts.py +74 -115
  11. mcli/ml/dashboard/components/metrics.py +24 -44
  12. mcli/ml/dashboard/components/tables.py +32 -40
  13. mcli/ml/dashboard/overview.py +102 -78
  14. mcli/ml/dashboard/pages/cicd.py +103 -56
  15. mcli/ml/dashboard/pages/debug_dependencies.py +35 -28
  16. mcli/ml/dashboard/pages/gravity_viz.py +374 -313
  17. mcli/ml/dashboard/pages/monte_carlo_predictions.py +50 -48
  18. mcli/ml/dashboard/pages/predictions_enhanced.py +396 -248
  19. mcli/ml/dashboard/pages/scrapers_and_logs.py +299 -273
  20. mcli/ml/dashboard/pages/test_portfolio.py +153 -121
  21. mcli/ml/dashboard/pages/trading.py +238 -169
  22. mcli/ml/dashboard/pages/workflows.py +129 -84
  23. mcli/ml/dashboard/streamlit_extras_utils.py +70 -79
  24. mcli/ml/dashboard/utils.py +24 -21
  25. mcli/ml/dashboard/warning_suppression.py +6 -4
  26. mcli/ml/database/session.py +16 -5
  27. mcli/ml/mlops/pipeline_orchestrator.py +1 -3
  28. mcli/ml/predictions/monte_carlo.py +6 -18
  29. mcli/ml/trading/alpaca_client.py +95 -96
  30. mcli/ml/trading/migrations.py +76 -40
  31. mcli/ml/trading/models.py +78 -60
  32. mcli/ml/trading/paper_trading.py +92 -74
  33. mcli/ml/trading/risk_management.py +106 -85
  34. mcli/ml/trading/trading_service.py +155 -110
  35. mcli/ml/training/train_model.py +1 -3
  36. mcli/self/self_cmd.py +71 -57
  37. mcli/workflow/daemon/daemon.py +2 -0
  38. mcli/workflow/model_service/openai_adapter.py +6 -2
  39. mcli/workflow/politician_trading/models.py +6 -2
  40. mcli/workflow/politician_trading/scrapers_corporate_registry.py +39 -88
  41. mcli/workflow/politician_trading/scrapers_free_sources.py +32 -39
  42. mcli/workflow/politician_trading/scrapers_third_party.py +21 -39
  43. mcli/workflow/politician_trading/seed_database.py +70 -89
  44. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/METADATA +1 -1
  45. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/RECORD +49 -49
  46. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/WHEEL +0 -0
  47. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/entry_points.txt +0 -0
  48. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/licenses/LICENSE +0 -0
  49. {mcli_framework-7.6.0.dist-info → mcli_framework-7.6.1.dist-info}/top_level.txt +0 -0
@@ -7,15 +7,29 @@ from typing import Dict, List, Optional, Tuple, Union
7
7
  from uuid import UUID
8
8
 
9
9
  import pandas as pd
10
- from sqlalchemy.orm import Session
11
10
  from sqlalchemy import and_, desc, func
11
+ from sqlalchemy.orm import Session
12
12
 
13
13
  from mcli.ml.trading.alpaca_client import AlpacaTradingClient, create_trading_client
14
14
  from mcli.ml.trading.models import (
15
- TradingAccount, Portfolio, Position, TradingOrder, PortfolioPerformanceSnapshot,
16
- TradingSignal, OrderStatus, OrderType, OrderSide, PositionSide, PortfolioType,
17
- TradingAccountCreate, PortfolioCreate, OrderCreate, PortfolioResponse,
18
- PositionResponse, OrderResponse, TradingSignalResponse
15
+ OrderCreate,
16
+ OrderResponse,
17
+ OrderSide,
18
+ OrderStatus,
19
+ OrderType,
20
+ Portfolio,
21
+ PortfolioCreate,
22
+ PortfolioPerformanceSnapshot,
23
+ PortfolioResponse,
24
+ PortfolioType,
25
+ Position,
26
+ PositionResponse,
27
+ PositionSide,
28
+ TradingAccount,
29
+ TradingAccountCreate,
30
+ TradingOrder,
31
+ TradingSignal,
32
+ TradingSignalResponse,
19
33
  )
20
34
 
21
35
  logger = logging.getLogger(__name__)
@@ -23,11 +37,13 @@ logger = logging.getLogger(__name__)
23
37
 
24
38
  class TradingService:
25
39
  """Service for managing trading operations"""
26
-
40
+
27
41
  def __init__(self, db_session: Session):
28
42
  self.db = db_session
29
-
30
- def create_trading_account(self, user_id: UUID, account_data: TradingAccountCreate) -> TradingAccount:
43
+
44
+ def create_trading_account(
45
+ self, user_id: UUID, account_data: TradingAccountCreate
46
+ ) -> TradingAccount:
31
47
  """Create a new trading account"""
32
48
  try:
33
49
  account = TradingAccount(
@@ -41,26 +57,27 @@ class TradingService:
41
57
  max_position_size=account_data.max_position_size,
42
58
  max_portfolio_risk=account_data.max_portfolio_risk,
43
59
  )
44
-
60
+
45
61
  self.db.add(account)
46
62
  self.db.commit()
47
63
  self.db.refresh(account)
48
-
64
+
49
65
  logger.info(f"Created trading account {account.id} for user {user_id}")
50
66
  return account
51
-
67
+
52
68
  except Exception as e:
53
69
  self.db.rollback()
54
70
  logger.error(f"Failed to create trading account: {e}")
55
71
  raise
56
-
72
+
57
73
  def get_trading_account(self, account_id: UUID) -> Optional[TradingAccount]:
58
74
  """Get trading account by ID"""
59
- return self.db.query(TradingAccount).filter(
60
- TradingAccount.id == account_id,
61
- TradingAccount.is_active == True
62
- ).first()
63
-
75
+ return (
76
+ self.db.query(TradingAccount)
77
+ .filter(TradingAccount.id == account_id, TradingAccount.is_active == True)
78
+ .first()
79
+ )
80
+
64
81
  def create_portfolio(self, account_id: UUID, portfolio_data: PortfolioCreate) -> Portfolio:
65
82
  """Create a new portfolio"""
66
83
  try:
@@ -72,86 +89,89 @@ class TradingService:
72
89
  current_value=Decimal(str(portfolio_data.initial_capital)),
73
90
  cash_balance=Decimal(str(portfolio_data.initial_capital)),
74
91
  )
75
-
92
+
76
93
  self.db.add(portfolio)
77
94
  self.db.commit()
78
95
  self.db.refresh(portfolio)
79
-
96
+
80
97
  logger.info(f"Created portfolio {portfolio.id} for account {account_id}")
81
98
  return portfolio
82
-
99
+
83
100
  except Exception as e:
84
101
  self.db.rollback()
85
102
  logger.error(f"Failed to create portfolio: {e}")
86
103
  raise
87
-
104
+
88
105
  def get_portfolio(self, portfolio_id: UUID) -> Optional[Portfolio]:
89
106
  """Get portfolio by ID"""
90
- return self.db.query(Portfolio).filter(
91
- Portfolio.id == portfolio_id,
92
- Portfolio.is_active == True
93
- ).first()
94
-
107
+ return (
108
+ self.db.query(Portfolio)
109
+ .filter(Portfolio.id == portfolio_id, Portfolio.is_active == True)
110
+ .first()
111
+ )
112
+
95
113
  def get_user_portfolios(self, user_id: UUID) -> List[Portfolio]:
96
114
  """Get all portfolios for a user"""
97
- return self.db.query(Portfolio).join(TradingAccount).filter(
98
- TradingAccount.user_id == user_id,
99
- Portfolio.is_active == True
100
- ).all()
101
-
115
+ return (
116
+ self.db.query(Portfolio)
117
+ .join(TradingAccount)
118
+ .filter(TradingAccount.user_id == user_id, Portfolio.is_active == True)
119
+ .all()
120
+ )
121
+
102
122
  def create_alpaca_client(self, account: TradingAccount) -> AlpacaTradingClient:
103
123
  """Create Alpaca client for trading account"""
104
124
  if not account.alpaca_api_key or not account.alpaca_secret_key:
105
125
  raise ValueError("Alpaca credentials not configured for this account")
106
-
126
+
107
127
  return create_trading_client(
108
128
  api_key=account.alpaca_api_key,
109
129
  secret_key=account.alpaca_secret_key,
110
- paper_trading=account.paper_trading
130
+ paper_trading=account.paper_trading,
111
131
  )
112
-
132
+
113
133
  def sync_portfolio_with_alpaca(self, portfolio: Portfolio) -> bool:
114
134
  """Sync portfolio data with Alpaca"""
115
135
  try:
116
136
  account = self.get_trading_account(portfolio.trading_account_id)
117
137
  if not account:
118
138
  return False
119
-
139
+
120
140
  alpaca_client = self.create_alpaca_client(account)
121
141
  alpaca_portfolio = alpaca_client.get_portfolio()
122
-
142
+
123
143
  # Update portfolio values
124
144
  portfolio.current_value = Decimal(str(alpaca_portfolio.portfolio_value))
125
145
  portfolio.cash_balance = Decimal(str(alpaca_portfolio.cash))
126
146
  portfolio.unrealized_pl = Decimal(str(alpaca_portfolio.unrealized_pl))
127
147
  portfolio.realized_pl = Decimal(str(alpaca_portfolio.realized_pl))
128
-
148
+
129
149
  # Calculate returns
130
150
  if portfolio.initial_capital > 0:
131
151
  total_return = portfolio.current_value - portfolio.initial_capital
132
152
  portfolio.total_return = float(total_return)
133
153
  portfolio.total_return_pct = float(total_return / portfolio.initial_capital * 100)
134
-
154
+
135
155
  # Update positions
136
156
  self._sync_positions(portfolio, alpaca_portfolio.positions)
137
-
157
+
138
158
  # Create performance snapshot
139
159
  self._create_performance_snapshot(portfolio)
140
-
160
+
141
161
  self.db.commit()
142
162
  logger.info(f"Synced portfolio {portfolio.id} with Alpaca")
143
163
  return True
144
-
164
+
145
165
  except Exception as e:
146
166
  self.db.rollback()
147
167
  logger.error(f"Failed to sync portfolio with Alpaca: {e}")
148
168
  return False
149
-
169
+
150
170
  def _sync_positions(self, portfolio: Portfolio, alpaca_positions: List):
151
171
  """Sync positions with Alpaca data"""
152
172
  # Clear existing positions
153
173
  self.db.query(Position).filter(Position.portfolio_id == portfolio.id).delete()
154
-
174
+
155
175
  # Add new positions
156
176
  for alpaca_pos in alpaca_positions:
157
177
  position = Position(
@@ -169,7 +189,7 @@ class TradingService:
169
189
  weight=float(alpaca_pos.market_value / portfolio.current_value),
170
190
  )
171
191
  self.db.add(position)
172
-
192
+
173
193
  def _create_performance_snapshot(self, portfolio: Portfolio):
174
194
  """Create daily performance snapshot"""
175
195
  snapshot = PortfolioPerformanceSnapshot(
@@ -187,7 +207,7 @@ class TradingService:
187
207
  positions_data=self._get_positions_data(portfolio.id),
188
208
  )
189
209
  self.db.add(snapshot)
190
-
210
+
191
211
  def _get_positions_data(self, portfolio_id: UUID) -> Dict:
192
212
  """Get positions data for snapshot"""
193
213
  positions = self.db.query(Position).filter(Position.portfolio_id == portfolio_id).all()
@@ -203,33 +223,36 @@ class TradingService:
203
223
  }
204
224
  for pos in positions
205
225
  }
206
-
207
- def place_order(self, portfolio_id: UUID, order_data: OrderCreate, check_risk: bool = True) -> TradingOrder:
226
+
227
+ def place_order(
228
+ self, portfolio_id: UUID, order_data: OrderCreate, check_risk: bool = True
229
+ ) -> TradingOrder:
208
230
  """Place a trading order"""
209
231
  try:
210
232
  portfolio = self.get_portfolio(portfolio_id)
211
233
  if not portfolio:
212
234
  raise ValueError("Portfolio not found")
213
-
235
+
214
236
  account = self.get_trading_account(portfolio.trading_account_id)
215
237
  if not account:
216
238
  raise ValueError("Trading account not found")
217
-
239
+
218
240
  # Check risk limits if enabled
219
241
  if check_risk:
220
242
  from mcli.ml.trading.risk_management import RiskManager
243
+
221
244
  risk_manager = RiskManager(self)
222
-
245
+
223
246
  order_dict = {
224
247
  "symbol": order_data.symbol,
225
248
  "quantity": order_data.quantity,
226
249
  "side": order_data.side.value,
227
250
  }
228
-
251
+
229
252
  risk_ok, warnings = risk_manager.check_risk_limits(portfolio_id, order_dict)
230
253
  if not risk_ok:
231
254
  raise ValueError(f"Order violates risk limits: {'; '.join(warnings)}")
232
-
255
+
233
256
  # Create order record
234
257
  order = TradingOrder(
235
258
  trading_account_id=account.id,
@@ -238,26 +261,28 @@ class TradingService:
238
261
  side=order_data.side,
239
262
  order_type=order_data.order_type,
240
263
  quantity=order_data.quantity,
241
- limit_price=Decimal(str(order_data.limit_price)) if order_data.limit_price else None,
264
+ limit_price=(
265
+ Decimal(str(order_data.limit_price)) if order_data.limit_price else None
266
+ ),
242
267
  stop_price=Decimal(str(order_data.stop_price)) if order_data.stop_price else None,
243
268
  remaining_quantity=order_data.quantity,
244
269
  time_in_force=order_data.time_in_force,
245
270
  extended_hours=order_data.extended_hours,
246
271
  )
247
-
272
+
248
273
  self.db.add(order)
249
274
  self.db.flush() # Get the ID
250
-
275
+
251
276
  # Place order with Alpaca if account has credentials
252
277
  if account.alpaca_api_key and account.alpaca_secret_key:
253
278
  alpaca_client = self.create_alpaca_client(account)
254
-
279
+
255
280
  if order_data.order_type == OrderType.MARKET:
256
281
  alpaca_order = alpaca_client.place_market_order(
257
282
  symbol=order_data.symbol,
258
283
  quantity=order_data.quantity,
259
284
  side=order_data.side.value,
260
- time_in_force=order_data.time_in_force
285
+ time_in_force=order_data.time_in_force,
261
286
  )
262
287
  elif order_data.order_type == OrderType.LIMIT:
263
288
  alpaca_order = alpaca_client.place_limit_order(
@@ -265,26 +290,26 @@ class TradingService:
265
290
  quantity=order_data.quantity,
266
291
  side=order_data.side.value,
267
292
  limit_price=order_data.limit_price,
268
- time_in_force=order_data.time_in_force
293
+ time_in_force=order_data.time_in_force,
269
294
  )
270
295
  else:
271
296
  raise ValueError(f"Unsupported order type: {order_data.order_type}")
272
-
297
+
273
298
  order.alpaca_order_id = alpaca_order.id
274
299
  order.status = OrderStatus.SUBMITTED
275
300
  order.submitted_at = datetime.utcnow()
276
-
301
+
277
302
  self.db.commit()
278
303
  self.db.refresh(order)
279
-
304
+
280
305
  logger.info(f"Placed order {order.id} for portfolio {portfolio_id}")
281
306
  return order
282
-
307
+
283
308
  except Exception as e:
284
309
  self.db.rollback()
285
310
  logger.error(f"Failed to place order: {e}")
286
311
  raise
287
-
312
+
288
313
  def get_portfolio_positions(self, portfolio_id: UUID) -> List[PositionResponse]:
289
314
  """Get all positions for a portfolio"""
290
315
  positions = self.db.query(Position).filter(Position.portfolio_id == portfolio_id).all()
@@ -308,13 +333,15 @@ class TradingService:
308
333
  )
309
334
  for pos in positions
310
335
  ]
311
-
312
- def get_portfolio_orders(self, portfolio_id: UUID, status: Optional[OrderStatus] = None) -> List[OrderResponse]:
336
+
337
+ def get_portfolio_orders(
338
+ self, portfolio_id: UUID, status: Optional[OrderStatus] = None
339
+ ) -> List[OrderResponse]:
313
340
  """Get orders for a portfolio"""
314
341
  query = self.db.query(TradingOrder).filter(TradingOrder.portfolio_id == portfolio_id)
315
342
  if status:
316
343
  query = query.filter(TradingOrder.status == status)
317
-
344
+
318
345
  orders = query.order_by(desc(TradingOrder.created_at)).all()
319
346
  return [
320
347
  OrderResponse(
@@ -325,7 +352,9 @@ class TradingService:
325
352
  quantity=order.quantity,
326
353
  limit_price=float(order.limit_price) if order.limit_price else None,
327
354
  stop_price=float(order.stop_price) if order.stop_price else None,
328
- average_fill_price=float(order.average_fill_price) if order.average_fill_price else None,
355
+ average_fill_price=(
356
+ float(order.average_fill_price) if order.average_fill_price else None
357
+ ),
329
358
  status=order.status,
330
359
  filled_quantity=order.filled_quantity,
331
360
  remaining_quantity=order.remaining_quantity,
@@ -339,34 +368,41 @@ class TradingService:
339
368
  )
340
369
  for order in orders
341
370
  ]
342
-
371
+
343
372
  def get_portfolio_performance(self, portfolio_id: UUID, days: int = 30) -> pd.DataFrame:
344
373
  """Get portfolio performance history"""
345
374
  end_date = datetime.utcnow()
346
375
  start_date = end_date - timedelta(days=days)
347
-
348
- snapshots = self.db.query(PortfolioPerformanceSnapshot).filter(
349
- PortfolioPerformanceSnapshot.portfolio_id == portfolio_id,
350
- PortfolioPerformanceSnapshot.snapshot_date >= start_date
351
- ).order_by(PortfolioPerformanceSnapshot.snapshot_date).all()
352
-
376
+
377
+ snapshots = (
378
+ self.db.query(PortfolioPerformanceSnapshot)
379
+ .filter(
380
+ PortfolioPerformanceSnapshot.portfolio_id == portfolio_id,
381
+ PortfolioPerformanceSnapshot.snapshot_date >= start_date,
382
+ )
383
+ .order_by(PortfolioPerformanceSnapshot.snapshot_date)
384
+ .all()
385
+ )
386
+
353
387
  data = []
354
388
  for snapshot in snapshots:
355
- data.append({
356
- "date": snapshot.snapshot_date,
357
- "portfolio_value": float(snapshot.portfolio_value),
358
- "cash_balance": float(snapshot.cash_balance),
359
- "daily_return": float(snapshot.daily_return),
360
- "daily_return_pct": snapshot.daily_return_pct,
361
- "total_return": float(snapshot.total_return),
362
- "total_return_pct": snapshot.total_return_pct,
363
- "volatility": snapshot.volatility,
364
- "sharpe_ratio": snapshot.sharpe_ratio,
365
- "max_drawdown": snapshot.max_drawdown,
366
- })
367
-
389
+ data.append(
390
+ {
391
+ "date": snapshot.snapshot_date,
392
+ "portfolio_value": float(snapshot.portfolio_value),
393
+ "cash_balance": float(snapshot.cash_balance),
394
+ "daily_return": float(snapshot.daily_return),
395
+ "daily_return_pct": snapshot.daily_return_pct,
396
+ "total_return": float(snapshot.total_return),
397
+ "total_return_pct": snapshot.total_return_pct,
398
+ "volatility": snapshot.volatility,
399
+ "sharpe_ratio": snapshot.sharpe_ratio,
400
+ "max_drawdown": snapshot.max_drawdown,
401
+ }
402
+ )
403
+
368
404
  return pd.DataFrame(data)
369
-
405
+
370
406
  def create_trading_signal(
371
407
  self,
372
408
  portfolio_id: UUID,
@@ -379,12 +415,14 @@ class TradingService:
379
415
  stop_loss: Optional[float] = None,
380
416
  take_profit: Optional[float] = None,
381
417
  position_size: Optional[float] = None,
382
- expires_hours: int = 24
418
+ expires_hours: int = 24,
383
419
  ) -> TradingSignal:
384
420
  """Create a trading signal"""
385
421
  try:
386
- expires_at = datetime.utcnow() + timedelta(hours=expires_hours) if expires_hours > 0 else None
387
-
422
+ expires_at = (
423
+ datetime.utcnow() + timedelta(hours=expires_hours) if expires_hours > 0 else None
424
+ )
425
+
388
426
  signal = TradingSignal(
389
427
  portfolio_id=portfolio_id,
390
428
  symbol=symbol,
@@ -398,27 +436,32 @@ class TradingService:
398
436
  position_size=position_size,
399
437
  expires_at=expires_at,
400
438
  )
401
-
439
+
402
440
  self.db.add(signal)
403
441
  self.db.commit()
404
442
  self.db.refresh(signal)
405
-
443
+
406
444
  logger.info(f"Created trading signal {signal.id} for {symbol}")
407
445
  return signal
408
-
446
+
409
447
  except Exception as e:
410
448
  self.db.rollback()
411
449
  logger.error(f"Failed to create trading signal: {e}")
412
450
  raise
413
-
451
+
414
452
  def get_active_signals(self, portfolio_id: UUID) -> List[TradingSignalResponse]:
415
453
  """Get active trading signals for a portfolio"""
416
- signals = self.db.query(TradingSignal).filter(
417
- TradingSignal.portfolio_id == portfolio_id,
418
- TradingSignal.is_active == True,
419
- TradingSignal.expires_at > datetime.utcnow()
420
- ).order_by(desc(TradingSignal.created_at)).all()
421
-
454
+ signals = (
455
+ self.db.query(TradingSignal)
456
+ .filter(
457
+ TradingSignal.portfolio_id == portfolio_id,
458
+ TradingSignal.is_active == True,
459
+ TradingSignal.expires_at > datetime.utcnow(),
460
+ )
461
+ .order_by(desc(TradingSignal.created_at))
462
+ .all()
463
+ )
464
+
422
465
  return [
423
466
  TradingSignalResponse(
424
467
  id=signal.id,
@@ -438,16 +481,16 @@ class TradingService:
438
481
  )
439
482
  for signal in signals
440
483
  ]
441
-
484
+
442
485
  def calculate_portfolio_metrics(self, portfolio_id: UUID) -> Dict:
443
486
  """Calculate portfolio performance metrics"""
444
487
  portfolio = self.get_portfolio(portfolio_id)
445
488
  if not portfolio:
446
489
  return {}
447
-
490
+
448
491
  # Get performance history
449
492
  performance_df = self.get_portfolio_performance(portfolio_id, days=90)
450
-
493
+
451
494
  if performance_df.empty:
452
495
  return {
453
496
  "total_return": 0.0,
@@ -458,16 +501,18 @@ class TradingService:
458
501
  "current_value": float(portfolio.current_value),
459
502
  "cash_balance": float(portfolio.cash_balance),
460
503
  }
461
-
504
+
462
505
  # Calculate metrics
463
506
  returns = performance_df["daily_return_pct"].dropna()
464
-
507
+
465
508
  total_return = performance_df["total_return"].iloc[-1] if not performance_df.empty else 0
466
- total_return_pct = performance_df["total_return_pct"].iloc[-1] if not performance_df.empty else 0
467
- volatility = returns.std() * (252 ** 0.5) if len(returns) > 1 else 0 # Annualized
509
+ total_return_pct = (
510
+ performance_df["total_return_pct"].iloc[-1] if not performance_df.empty else 0
511
+ )
512
+ volatility = returns.std() * (252**0.5) if len(returns) > 1 else 0 # Annualized
468
513
  sharpe_ratio = (returns.mean() * 252) / volatility if volatility > 0 else 0 # Annualized
469
514
  max_drawdown = performance_df["max_drawdown"].max() if not performance_df.empty else 0
470
-
515
+
471
516
  return {
472
517
  "total_return": total_return,
473
518
  "total_return_pct": total_return_pct,
@@ -477,4 +522,4 @@ class TradingService:
477
522
  "current_value": float(portfolio.current_value),
478
523
  "cash_balance": float(portfolio.cash_balance),
479
524
  "num_positions": len(self.get_portfolio_positions(portfolio_id)),
480
- }
525
+ }
@@ -452,9 +452,7 @@ def save_model(
452
452
  "model_architecture": {
453
453
  "input_size": model.network[0].in_features,
454
454
  "hidden_layers": [
455
- layer.out_features
456
- for layer in model.network
457
- if isinstance(layer, nn.Linear)
455
+ layer.out_features for layer in model.network if isinstance(layer, nn.Linear)
458
456
  ][:-1],
459
457
  },
460
458
  "scaler_mean": scaler.mean_.tolist(),