use std::collections::BTreeMap; use chrono::NaiveDate; use crate::events::OrderSide; pub const STOCK_PIT_TAX_CHANGE_DATE: (i32, u32, u32) = (2023, 8, 28); #[derive(Debug, Clone, Copy)] pub struct TradingCost { pub commission: f64, pub stamp_tax: f64, } impl TradingCost { pub fn total(self) -> f64 { self.commission + self.stamp_tax } } pub trait CostModel { fn calculate(&self, date: NaiveDate, side: OrderSide, gross_amount: f64) -> TradingCost; fn calculate_with_order_state( &self, date: NaiveDate, side: OrderSide, gross_amount: f64, _order_id: Option, _commission_state: &mut BTreeMap, ) -> TradingCost { self.calculate(date, side, gross_amount) } } #[derive(Debug, Clone, Copy)] pub struct ChinaAShareCostModel { pub commission_rate: f64, pub stamp_tax_rate_before_change: f64, pub stamp_tax_rate_after_change: f64, pub minimum_commission: f64, } impl Default for ChinaAShareCostModel { fn default() -> Self { Self { commission_rate: 0.0003, stamp_tax_rate_before_change: 0.001, stamp_tax_rate_after_change: 0.0005, minimum_commission: 5.0, } } } impl ChinaAShareCostModel { pub fn commission_for(&self, gross_amount: f64) -> f64 { if gross_amount <= 0.0 { return 0.0; } (gross_amount * self.commission_rate).max(self.minimum_commission) } pub fn stamp_tax_rate_for(&self, date: NaiveDate) -> f64 { let change_date = NaiveDate::from_ymd_opt( STOCK_PIT_TAX_CHANGE_DATE.0, STOCK_PIT_TAX_CHANGE_DATE.1, STOCK_PIT_TAX_CHANGE_DATE.2, ) .expect("valid pit tax change date"); if date < change_date { self.stamp_tax_rate_before_change } else { self.stamp_tax_rate_after_change } } pub fn stamp_tax_for(&self, date: NaiveDate, side: OrderSide, gross_amount: f64) -> f64 { if gross_amount <= 0.0 || side == OrderSide::Buy { return 0.0; } gross_amount * self.stamp_tax_rate_for(date) } pub fn commission_for_order_fill( &self, gross_amount: f64, order_id: Option, commission_state: &mut BTreeMap, ) -> f64 { if gross_amount <= 0.0 { return 0.0; } let raw_commission = gross_amount * self.commission_rate; let Some(order_id) = order_id else { return raw_commission.max(self.minimum_commission); }; let remaining_minimum = commission_state .entry(order_id) .or_insert(self.minimum_commission); if raw_commission > *remaining_minimum { let charged = if (*remaining_minimum - self.minimum_commission).abs() < 1e-12 { raw_commission } else { raw_commission - *remaining_minimum }; *remaining_minimum = 0.0; charged } else { let charged = if (*remaining_minimum - self.minimum_commission).abs() < 1e-12 { self.minimum_commission } else { 0.0 }; *remaining_minimum -= raw_commission; charged } } } impl CostModel for ChinaAShareCostModel { fn calculate(&self, date: NaiveDate, side: OrderSide, gross_amount: f64) -> TradingCost { if gross_amount <= 0.0 { return TradingCost { commission: 0.0, stamp_tax: 0.0, }; } let commission = self.commission_for(gross_amount); let stamp_tax = self.stamp_tax_for(date, side, gross_amount); TradingCost { commission, stamp_tax, } } fn calculate_with_order_state( &self, date: NaiveDate, side: OrderSide, gross_amount: f64, order_id: Option, commission_state: &mut BTreeMap, ) -> TradingCost { if gross_amount <= 0.0 { return TradingCost { commission: 0.0, stamp_tax: 0.0, }; } let commission = self.commission_for_order_fill(gross_amount, order_id, commission_state); let stamp_tax = self.stamp_tax_for(date, side, gross_amount); TradingCost { commission, stamp_tax, } } }