90 lines
3.3 KiB
Python
90 lines
3.3 KiB
Python
from typing import Dict, Any
|
|
from PySide6.QtCore import QObject, Signal
|
|
from loguru import logger
|
|
from storage.models import Position
|
|
|
|
|
|
class Portfolio(QObject):
    """Manages the trading portfolio, including positions and P&L.

    Positions are cached in ``self.positions`` keyed by symbol and persisted
    through the database session passed to the constructor.
    """

    class Signals(QObject):
        """Qt signals published by :class:`Portfolio`."""

        # Emitted with a Position after it is loaded or successfully saved.
        position_updated = Signal(Position)
        # Emitted with the total unrealized P&L from calculate_total_pnl().
        portfolio_value_updated = Signal(float)

    def __init__(self, db_session):
        """Create the portfolio and load open positions from the database.

        Args:
            db_session: Session used to query and persist ``Position`` rows.
                It must provide ``query``, ``add``, ``commit`` and
                ``rollback`` (presumably a SQLAlchemy Session — confirm).
        """
        super().__init__()
        self.db_session = db_session
        # Cache keyed by symbol. It can also hold positions that were updated
        # with is_open=False during this run; accessors filter on is_open.
        self.positions: Dict[str, Position] = {}
        self.signals = self.Signals()
        self._load_positions_from_db()

    def _load_positions_from_db(self) -> None:
        """Load open positions from the database into the cache.

        This is best-effort: on failure the error is logged with a traceback,
        the session is rolled back, and the portfolio starts empty.
        """
        try:
            active_positions = (
                self.db_session.query(Position).filter_by(is_open=True).all()
            )
        except Exception as e:
            logger.exception(f"Error loading positions from DB: {e}")
            # Roll back so a failed transaction does not block later commits
            # made through the same session.
            self.db_session.rollback()
            return

        for pos in active_positions:
            self.positions[pos.symbol] = pos
            # Emit for initial UI load.
            # NOTE(review): this runs inside __init__, before callers can
            # connect to self.signals, so these emissions likely reach no slot.
            self.signals.position_updated.emit(pos)

        logger.info(f"Loaded {len(active_positions)} active positions from DB.")

    def update_position(
        self,
        symbol: str,
        quantity: int,
        average_price: float,
        current_price: float,
        is_open: bool = True,
    ) -> None:
        """Add or update a position and persist it.

        ``market_value`` is ``quantity * current_price`` and
        ``unrealized_pnl`` is ``(current_price - average_price) * quantity``.

        The cache is updated and ``position_updated`` is emitted only after
        the commit succeeds. On failure the session is rolled back and the
        error is logged.

        Args:
            symbol: Instrument symbol, used as the cache key.
            quantity: Position size.
            average_price: Average entry price.
            current_price: Latest market price.
            is_open: Whether the position is still open.
        """
        market_value = quantity * current_price
        unrealized_pnl = (current_price - average_price) * quantity

        position = self.positions.get(symbol)
        if position is not None:
            position.quantity = quantity
            position.average_price = average_price
            position.current_price = current_price
            position.market_value = market_value
            position.unrealized_pnl = unrealized_pnl
            position.is_open = is_open
        else:
            position = Position(
                symbol=symbol,
                quantity=quantity,
                average_price=average_price,
                current_price=current_price,
                market_value=market_value,
                unrealized_pnl=unrealized_pnl,
                is_open=is_open,
            )
            self.db_session.add(position)

        try:
            self.db_session.commit()
        except Exception as e:
            # Do not cache a position that failed to persist. Otherwise a new,
            # rolled-back object would stay cached and later updates would
            # never re-add it to the session. (Assuming SQLAlchemy semantics,
            # rollback also expires unsaved edits on an already-cached object.)
            self.db_session.rollback()
            logger.exception(f"Failed to save position for {symbol}: {e}")
            return

        self.positions[symbol] = position
        self.signals.position_updated.emit(position)
        logger.info(f"Position updated for {symbol}: Qty={quantity}, AvgPrice={average_price}, CurrentPrice={current_price}")

    def get_position(self, symbol: str) -> Position | None:
        """Return the cached position for ``symbol`` (open or closed), or None."""
        return self.positions.get(symbol)

    def get_all_positions(self) -> Dict[str, Position]:
        """Return a new dict of the cached positions whose ``is_open`` is truthy."""
        return {s: p for s, p in self.positions.items() if p.is_open}

    def calculate_total_pnl(self) -> float:
        """Return the total unrealized P&L of open positions and emit it.

        The total is emitted on ``signals.portfolio_value_updated``. The sum
        starts at ``0.0`` so an empty portfolio yields a float, as annotated.
        """
        total_pnl = sum(
            (
                # A position loaded from the DB may have a NULL P&L
                # (assumption — confirm the column's nullability); count it as 0.
                pos.unrealized_pnl or 0.0
                for pos in self.positions.values()
                if pos.is_open
            ),
            0.0,
        )
        self.signals.portfolio_value_updated.emit(total_pnl)
        return total_pnl