from typing import Dict, Any

from PySide6.QtCore import QObject, Signal
from loguru import logger

from storage.models import Position


class Portfolio(QObject):
    """Manages the trading portfolio, including positions and P&L.

    Positions are cached in ``self.positions`` keyed by symbol and persisted
    through ``db_session``, which is used via ``query``/``add``/``commit``/
    ``rollback`` (presumably a SQLAlchemy session — confirm against caller).
    """

    class Signals(QObject):
        """Qt signals published by :class:`Portfolio`."""

        # Emitted with the Position after it is loaded or successfully saved.
        position_updated = Signal(Position)
        # Emitted with the total unrealized P&L by calculate_total_pnl().
        # NOTE(review): despite the name, the payload is P&L, not market value.
        portfolio_value_updated = Signal(float)

    def __init__(self, db_session):
        """Create the portfolio and load open positions from ``db_session``."""
        super().__init__()
        self.db_session = db_session
        # Symbol -> Position. May also hold positions updated with
        # is_open=False; readers filter on ``is_open`` where it matters.
        self.positions: Dict[str, Position] = {}
        self.signals = self.Signals()
        self._load_positions_from_db()

    def _load_positions_from_db(self):
        """Load open positions (``is_open=True``) into the in-memory cache.

        Any error is logged and swallowed so construction still succeeds,
        possibly with an empty or partially filled cache.
        """
        try:
            active_positions = (
                self.db_session.query(Position).filter_by(is_open=True).all()
            )
            for pos in active_positions:
                self.positions[pos.symbol] = pos
                # NOTE(review): this runs inside __init__, before any caller
                # can connect to self.signals, so these emissions reach no
                # receivers. UI code likely needs get_all_positions() after
                # construction for the initial load — confirm.
                self.signals.position_updated.emit(pos)
            logger.info(f"Loaded {len(active_positions)} active positions from DB.")
        except Exception as e:
            logger.exception(f"Error loading positions from DB: {e}")

    def update_position(
        self,
        symbol: str,
        quantity: int,
        average_price: float,
        current_price: float,
        is_open: bool = True,
    ) -> None:
        """
        Adds or updates a position in the portfolio and persists it.

        ``market_value`` and ``unrealized_pnl`` are derived from the given
        quantity and prices.

        Args:
            symbol: Key of the position in the cache.
            quantity: Position size.
            average_price: Average entry price.
            current_price: Latest price used for valuation.
            is_open: Whether the position is still open.

        On success ``signals.position_updated`` is emitted. If the commit
        fails, the session is rolled back, the error is logged, and no
        signal is emitted; a position created by this call is also removed
        from the cache again.
        """
        market_value = quantity * current_price
        unrealized_pnl = (current_price - average_price) * quantity

        position = self.positions.get(symbol)
        is_new = position is None
        if is_new:
            position = Position(
                symbol=symbol,
                quantity=quantity,
                average_price=average_price,
                current_price=current_price,
                market_value=market_value,
                unrealized_pnl=unrealized_pnl,
                is_open=is_open,
            )
            self.db_session.add(position)
            self.positions[symbol] = position
        else:
            position.quantity = quantity
            position.average_price = average_price
            position.current_price = current_price
            position.market_value = market_value
            position.unrealized_pnl = unrealized_pnl
            position.is_open = is_open

        try:
            self.db_session.commit()
        except Exception as e:
            self.db_session.rollback()
            if is_new:
                # A rolled-back pending object is no longer tracked by the
                # session; keeping it cached would make later updates mutate
                # an object that is never persisted. Dropping it lets the
                # next call re-create and re-add it.
                self.positions.pop(symbol, None)
            logger.exception(f"Failed to save position for {symbol}: {e}")
        else:
            # Only announce and log once the data is actually committed.
            self.signals.position_updated.emit(position)
            logger.info(f"Position updated for {symbol}: Qty={quantity}, AvgPrice={average_price}, CurrentPrice={current_price}")

    def get_position(self, symbol: str) -> Position | None:
        """
        Returns the cached position for ``symbol``, or None if there is none.

        The returned position may be closed (``is_open`` False).
        """
        return self.positions.get(symbol)

    def get_all_positions(self) -> Dict[str, Position]:
        """
        Returns a new dict of the open positions, keyed by symbol.
        """
        return {s: p for s, p in self.positions.items() if p.is_open}

    def calculate_total_pnl(self) -> float:
        """
        Calculates the total unrealized P&L over open positions.

        Side effect: emits ``signals.portfolio_value_updated`` with the total.
        """
        total_pnl = sum(pos.unrealized_pnl for pos in self.positions.values() if pos.is_open)
        self.signals.portfolio_value_updated.emit(total_pnl)
        return total_pnl