use chrono::NaiveDate; use indexmap::IndexMap; use serde::Serialize; use crate::data::{DataSet, DataSetError, PriceField}; #[derive(Debug, Clone)] pub struct PositionLot { pub acquired_date: NaiveDate, pub quantity: u32, pub price: f64, } #[derive(Debug, Clone)] pub struct Position { pub symbol: String, pub quantity: u32, pub average_cost: f64, pub last_price: f64, pub realized_pnl: f64, lots: Vec, } impl Position { pub fn new(symbol: impl Into) -> Self { Self { symbol: symbol.into(), quantity: 0, average_cost: 0.0, last_price: 0.0, realized_pnl: 0.0, lots: Vec::new(), } } pub fn is_flat(&self) -> bool { self.quantity == 0 } pub fn buy(&mut self, date: NaiveDate, quantity: u32, price: f64) { if quantity == 0 { return; } self.lots.push(PositionLot { acquired_date: date, quantity, price, }); self.quantity += quantity; self.last_price = price; self.recalculate_average_cost(); } pub fn sell(&mut self, quantity: u32, price: f64) -> Result { if quantity > self.quantity { return Err(format!( "sell quantity {} exceeds current quantity {} for {}", quantity, self.quantity, self.symbol )); } let mut remaining = quantity; let mut realized = 0.0; while remaining > 0 { let Some(first_lot) = self.lots.first_mut() else { return Err(format!("position {} has no lots to sell", self.symbol)); }; let lot_sell = remaining.min(first_lot.quantity); realized += (price - first_lot.price) * lot_sell as f64; first_lot.quantity -= lot_sell; remaining -= lot_sell; if first_lot.quantity == 0 { self.lots.remove(0); } } self.quantity -= quantity; self.last_price = price; self.realized_pnl += realized; self.recalculate_average_cost(); Ok(realized) } pub fn sellable_qty(&self, date: NaiveDate) -> u32 { self.lots .iter() .filter(|lot| lot.acquired_date < date) .map(|lot| lot.quantity) .sum() } pub fn market_value(&self) -> f64 { self.quantity as f64 * self.last_price } pub fn unrealized_pnl(&self) -> f64 { (self.last_price - self.average_cost) * self.quantity as f64 } pub fn holding_return(&self, price: f64) -> Option { if self.quantity == 0 || self.average_cost <= 0.0 { None } else { Some((price / self.average_cost) - 1.0) } } fn recalculate_average_cost(&mut self) { if self.quantity == 0 { self.average_cost = 0.0; return; } let total_cost = self .lots .iter() .map(|lot| lot.price * lot.quantity as f64) .sum::(); self.average_cost = total_cost / self.quantity as f64; } pub fn apply_cash_dividend(&mut self, dividend_per_share: f64) -> f64 { if self.quantity == 0 || !dividend_per_share.is_finite() || dividend_per_share == 0.0 { return 0.0; } for lot in &mut self.lots { lot.price -= dividend_per_share; } self.average_cost -= dividend_per_share; self.last_price -= dividend_per_share; self.quantity as f64 * dividend_per_share } pub fn apply_split_ratio(&mut self, ratio: f64) -> i32 { if self.quantity == 0 || !ratio.is_finite() || ratio <= 0.0 || (ratio - 1.0).abs() < 1e-9 { return 0; } let old_quantity = self.quantity; let mut scaled_lots = self .lots .iter() .map(|lot| PositionLot { acquired_date: lot.acquired_date, quantity: round_half_up_u32(lot.quantity as f64 * ratio), price: lot.price / ratio, }) .collect::>(); let expected_total = round_half_up_u32(old_quantity as f64 * ratio); let scaled_total = scaled_lots.iter().map(|lot| lot.quantity).sum::(); if let Some(last_lot) = scaled_lots.last_mut() { if scaled_total < expected_total { last_lot.quantity += expected_total - scaled_total; } else if scaled_total > expected_total { last_lot.quantity = last_lot .quantity .saturating_sub(scaled_total - expected_total); } } scaled_lots.retain(|lot| lot.quantity > 0); self.lots = scaled_lots; self.quantity = self.lots.iter().map(|lot| lot.quantity).sum(); self.last_price /= ratio; self.recalculate_average_cost(); self.quantity as i32 - old_quantity as i32 } } #[derive(Debug, Clone)] pub struct PortfolioState { cash: f64, positions: IndexMap, cash_receivables: Vec, } #[derive(Debug, Clone)] pub(crate) struct SuccessorConversionOutcome { pub old_symbol: String, pub new_symbol: String, pub old_quantity: u32, pub new_quantity_delta: i32, pub new_quantity_after: u32, pub new_average_cost_after: f64, pub cash_delta: f64, } impl PortfolioState { pub fn new(initial_cash: f64) -> Self { Self { cash: initial_cash, positions: IndexMap::new(), cash_receivables: Vec::new(), } } pub fn cash(&self) -> f64 { self.cash } pub fn positions(&self) -> &IndexMap { &self.positions } pub fn position(&self, symbol: &str) -> Option<&Position> { self.positions.get(symbol) } pub fn position_mut_if_exists(&mut self, symbol: &str) -> Option<&mut Position> { self.positions.get_mut(symbol) } pub fn position_mut(&mut self, symbol: &str) -> &mut Position { self.positions .entry(symbol.to_string()) .or_insert_with(|| Position::new(symbol)) } pub fn apply_cash_delta(&mut self, delta: f64) { self.cash += delta; } pub fn prune_flat_positions(&mut self) { self.positions.retain(|_, position| !position.is_flat()); } pub fn add_cash_receivable(&mut self, receivable: CashReceivable) { self.cash_receivables.push(receivable); } pub fn settle_cash_receivables(&mut self, date: NaiveDate) -> Vec { let mut settled = Vec::new(); let mut pending = Vec::new(); for receivable in self.cash_receivables.drain(..) { if receivable.payable_date <= date { self.cash += receivable.amount; settled.push(receivable); } else { pending.push(receivable); } } self.cash_receivables = pending; settled } pub fn cash_receivables(&self) -> &[CashReceivable] { &self.cash_receivables } pub fn update_prices( &mut self, date: NaiveDate, data: &DataSet, field: PriceField, ) -> Result<(), DataSetError> { for position in self.positions.values_mut() { let price = data.price(date, &position.symbol, field).ok_or_else(|| { DataSetError::MissingSnapshot { kind: match field { PriceField::DayOpen => "day open price", PriceField::Open => "open price", PriceField::Close => "close price", PriceField::Last => "last price", }, date, symbol: position.symbol.clone(), } })?; position.last_price = price; } Ok(()) } pub fn market_value(&self) -> f64 { self.positions.values().map(Position::market_value).sum() } pub fn total_equity(&self) -> f64 { self.cash + self.market_value() } pub fn holdings_summary(&self, date: NaiveDate) -> Vec { self.positions .values() .filter(|position| position.quantity > 0) .map(|position| HoldingSummary { date, symbol: position.symbol.clone(), quantity: position.quantity, average_cost: position.average_cost, last_price: position.last_price, market_value: position.market_value(), unrealized_pnl: position.unrealized_pnl(), realized_pnl: position.realized_pnl, }) .collect() } pub(crate) fn apply_successor_conversion( &mut self, old_symbol: &str, new_symbol: &str, ratio: f64, cash_per_old_share: f64, ) -> Option { if !ratio.is_finite() || ratio <= 0.0 { return None; } let old_symbol_owned = old_symbol.to_string(); let old_position = self.positions.shift_remove(old_symbol)?; if old_position.quantity == 0 { return None; } let old_quantity = old_position.quantity; let last_price = old_position.last_price; let realized_pnl = old_position.realized_pnl; let mut converted_lots = old_position .lots .into_iter() .map(|lot| PositionLot { acquired_date: lot.acquired_date, quantity: round_half_up_u32(lot.quantity as f64 * ratio), price: lot.price / ratio, }) .collect::>(); let expected_total = round_half_up_u32(old_quantity as f64 * ratio); let scaled_total = converted_lots.iter().map(|lot| lot.quantity).sum::(); if let Some(last_lot) = converted_lots.last_mut() { if scaled_total < expected_total { last_lot.quantity += expected_total - scaled_total; } else if scaled_total > expected_total { last_lot.quantity = last_lot .quantity .saturating_sub(scaled_total - expected_total); } } converted_lots.retain(|lot| lot.quantity > 0); let converted_quantity = converted_lots.iter().map(|lot| lot.quantity).sum::(); let converted_last_price = if last_price > 0.0 { last_price / ratio } else { 0.0 }; let successor = self .positions .entry(new_symbol.to_string()) .or_insert_with(|| Position::new(new_symbol)); successor.lots.extend(converted_lots); successor.quantity = successor.lots.iter().map(|lot| lot.quantity).sum(); successor.realized_pnl += realized_pnl; if converted_last_price > 0.0 { successor.last_price = converted_last_price; } successor.recalculate_average_cost(); Some(SuccessorConversionOutcome { old_symbol: old_symbol_owned, new_symbol: new_symbol.to_string(), old_quantity, new_quantity_delta: converted_quantity as i32, new_quantity_after: successor.quantity, new_average_cost_after: successor.average_cost, cash_delta: if cash_per_old_share.is_finite() { old_quantity as f64 * cash_per_old_share } else { 0.0 }, }) } } #[cfg(test)] mod tests { use super::*; #[test] fn positions_preserve_insertion_order() { let date = NaiveDate::from_ymd_opt(2025, 1, 2).unwrap(); let mut portfolio = PortfolioState::new(10_000.0); portfolio.position_mut("603657.SH").buy(date, 100, 10.0); portfolio.position_mut("001266.SZ").buy(date, 100, 10.0); portfolio.position_mut("601798.SH").buy(date, 100, 10.0); let symbols = portfolio.positions().keys().cloned().collect::>(); assert_eq!( symbols, vec![ "603657.SH".to_string(), "001266.SZ".to_string(), "601798.SH".to_string() ] ); } } #[derive(Debug, Clone, Serialize)] pub struct HoldingSummary { #[serde(with = "date_format")] pub date: NaiveDate, pub symbol: String, pub quantity: u32, pub average_cost: f64, pub last_price: f64, pub market_value: f64, pub unrealized_pnl: f64, pub realized_pnl: f64, } #[derive(Debug, Clone)] pub struct CashReceivable { pub symbol: String, pub ex_date: NaiveDate, pub payable_date: NaiveDate, pub amount: f64, pub reason: String, } mod date_format { use chrono::NaiveDate; use serde::Serializer; const FORMAT: &str = "%Y-%m-%d"; pub fn serialize(date: &NaiveDate, serializer: S) -> Result where S: Serializer, { serializer.serialize_str(&date.format(FORMAT).to_string()) } } fn round_half_up_u32(value: f64) -> u32 { if !value.is_finite() || value <= 0.0 { 0 } else { value.round() as u32 } }