37172-vm/algo_trader/core/portfolio.py
2025-12-27 04:26:41 +00:00

90 lines
3.3 KiB
Python

from typing import Dict, Any
from PySide6.QtCore import QObject, Signal
from loguru import logger
from storage.models import Position
class Portfolio(QObject):
    """Manages the trading portfolio, including positions and P&L.

    Positions are cached in memory, keyed by symbol, and persisted through
    the ``db_session`` handed to the constructor. Changes are announced
    through the Qt signals exposed on ``self.signals``.
    """

    class Signals(QObject):
        # Emitted with the affected Position after it is loaded or saved.
        position_updated = Signal(Position)
        # Emitted by calculate_total_pnl() with the total unrealized P&L
        # (note: the payload is P&L, not market value, despite the name).
        portfolio_value_updated = Signal(float)

    def __init__(self, db_session):
        """Create the portfolio and load open positions from the database.

        Args:
            db_session: Database session used for persistence. It must
                provide ``query``, ``add``, ``commit`` and ``rollback``
                (presumably a SQLAlchemy ``Session`` -- confirm at caller).
        """
        super().__init__()
        self.db_session = db_session
        self.positions: Dict[str, Position] = {}
        self.signals = self.Signals()
        self._load_positions_from_db()

    def _load_positions_from_db(self) -> None:
        """Load positions flagged ``is_open`` from the database into the cache.

        Failures are logged and swallowed so that a database problem does not
        prevent the portfolio from being constructed; the cache then holds
        whatever was loaded before the error.

        NOTE: this runs inside ``__init__``, right after ``self.signals`` is
        created, so no outside code can have connected a slot yet. Consumers
        that need the initial state should call ``get_all_positions()``
        instead of relying on these emissions.
        """
        try:
            active_positions = (
                self.db_session.query(Position).filter_by(is_open=True).all()
            )
            for pos in active_positions:
                self.positions[pos.symbol] = pos
                self.signals.position_updated.emit(pos)  # Emit for initial UI load
            logger.info(f"Loaded {len(active_positions)} active positions from DB.")
        except Exception as e:
            logger.error(f"Error loading positions from DB: {e}")

    def update_position(
        self,
        symbol: str,
        quantity: int,
        average_price: float,
        current_price: float,
        is_open: bool = True,
    ) -> None:
        """Add or update a position and persist the change.

        ``market_value`` and ``unrealized_pnl`` are derived from the given
        quantity and prices. On a successful commit ``position_updated`` is
        emitted with the position. If the commit fails the session is rolled
        back, the error is logged, no signal is emitted and a position that
        was new in this call is not added to the in-memory cache.

        Args:
            symbol: Key identifying the position.
            quantity: Position size.
            average_price: Average entry price.
            current_price: Latest price used for valuation.
            is_open: Whether the position is still open. Closed positions
                stay cached but are excluded from ``get_all_positions()``
                and ``calculate_total_pnl()``.
        """
        market_value = quantity * current_price
        unrealized_pnl = (current_price - average_price) * quantity

        position = self.positions.get(symbol)
        # Explicit None check: do not depend on the truthiness of an ORM object.
        is_new = position is None

        if is_new:
            position = Position(
                symbol=symbol,
                quantity=quantity,
                average_price=average_price,
                current_price=current_price,
                market_value=market_value,
                unrealized_pnl=unrealized_pnl,
                is_open=is_open,
            )
        else:
            position.quantity = quantity
            position.average_price = average_price
            position.current_price = current_price
            position.market_value = market_value
            position.unrealized_pnl = unrealized_pnl
            position.is_open = is_open

        try:
            if is_new:
                self.db_session.add(position)
            self.db_session.commit()
        except Exception as e:
            self.db_session.rollback()
            logger.error(f"Failed to save position for {symbol}: {e}")
        else:
            # Cache a new position only once it is durably stored, so the
            # in-memory view never contains a row the database rejected.
            if is_new:
                self.positions[symbol] = position
            # Kept outside the try block: an error raised while notifying
            # listeners must not be reported as a failed save.
            self.signals.position_updated.emit(position)
            logger.info(
                f"Position updated for {symbol}: Qty={quantity}, "
                f"AvgPrice={average_price}, CurrentPrice={current_price}"
            )

    def get_position(self, symbol: str) -> Position | None:
        """Return the cached position for ``symbol``, or ``None`` if absent.

        The returned position may be closed; check ``is_open`` if that matters.
        """
        return self.positions.get(symbol)

    def get_all_positions(self) -> Dict[str, Position]:
        """Return a new dict of the cached positions that are open, by symbol."""
        return {s: p for s, p in self.positions.items() if p.is_open}

    def calculate_total_pnl(self) -> float:
        """Sum the unrealized P&L of all open positions and broadcast it.

        Emits ``portfolio_value_updated`` with the total before returning it.
        A position whose ``unrealized_pnl`` is ``None`` counts as 0.0.
        """
        total_pnl = sum(
            (
                pos.unrealized_pnl or 0.0
                for pos in self.positions.values()
                if pos.is_open
            ),
            0.0,
        )
        self.signals.portfolio_value_updated.emit(total_pnl)
        return total_pnl