初始化回测核心引擎骨架

This commit is contained in:
zsb
2026-04-06 23:56:37 -07:00
commit 334864cbc5
25 changed files with 2878 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
[package]
name = "fidc-core"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
[dependencies]
chrono.workspace = true
serde.workspace = true
thiserror.workspace = true

View File

@@ -0,0 +1,390 @@
use std::collections::{BTreeMap, BTreeSet};
use chrono::NaiveDate;
use crate::cost::CostModel;
use crate::data::{DataSet, PriceField};
use crate::engine::BacktestError;
use crate::events::{AccountEvent, FillEvent, OrderEvent, OrderSide, OrderStatus, PositionEvent};
use crate::portfolio::PortfolioState;
use crate::rules::EquityRuleHooks;
use crate::strategy::StrategyDecision;
#[derive(Debug, Default)]
pub struct BrokerExecutionReport {
pub order_events: Vec<OrderEvent>,
pub fill_events: Vec<FillEvent>,
pub position_events: Vec<PositionEvent>,
pub account_events: Vec<AccountEvent>,
}
pub struct BrokerSimulator<C, R> {
cost_model: C,
rules: R,
board_lot_size: u32,
}
impl<C, R> BrokerSimulator<C, R> {
pub fn new(cost_model: C, rules: R) -> Self {
Self {
cost_model,
rules,
board_lot_size: 100,
}
}
}
impl<C, R> BrokerSimulator<C, R>
where
C: CostModel,
R: EquityRuleHooks,
{
pub fn execute(
&self,
date: NaiveDate,
portfolio: &mut PortfolioState,
data: &DataSet,
decision: &StrategyDecision,
) -> Result<BrokerExecutionReport, BacktestError> {
let mut report = BrokerExecutionReport::default();
let target_quantities = if decision.rebalance {
self.target_quantities(date, portfolio, data, &decision.target_weights)?
} else {
BTreeMap::new()
};
let mut sell_symbols = BTreeSet::new();
sell_symbols.extend(portfolio.positions().keys().cloned());
sell_symbols.extend(decision.exit_symbols.iter().cloned());
sell_symbols.extend(target_quantities.keys().cloned());
for symbol in sell_symbols {
let current_qty = portfolio.position(&symbol).map(|pos| pos.quantity).unwrap_or(0);
if current_qty == 0 {
continue;
}
let target_qty = if decision.exit_symbols.contains(&symbol) {
0
} else if decision.rebalance {
*target_quantities.get(&symbol).unwrap_or(&0)
} else {
current_qty
};
if current_qty > target_qty {
let requested_qty = current_qty - target_qty;
self.process_sell(
date,
portfolio,
data,
&symbol,
requested_qty,
sell_reason(decision, &symbol),
&mut report,
)?;
}
}
if decision.rebalance {
for (symbol, target_qty) in target_quantities {
let current_qty = portfolio.position(&symbol).map(|pos| pos.quantity).unwrap_or(0);
if target_qty > current_qty {
let requested_qty = target_qty - current_qty;
self.process_buy(
date,
portfolio,
data,
&symbol,
requested_qty,
"rebalance_buy",
&mut report,
)?;
}
}
}
portfolio.prune_flat_positions();
Ok(report)
}
fn target_quantities(
&self,
date: NaiveDate,
portfolio: &PortfolioState,
data: &DataSet,
target_weights: &BTreeMap<String, f64>,
) -> Result<BTreeMap<String, u32>, BacktestError> {
let equity = self.total_equity_at(date, portfolio, data, PriceField::Open)?;
let mut targets = BTreeMap::new();
for (symbol, weight) in target_weights {
let price = data
.price(date, symbol, PriceField::Open)
.ok_or_else(|| BacktestError::MissingPrice {
date,
symbol: symbol.clone(),
field: "open",
})?;
let raw_qty = ((equity * weight) / price).floor() as u32;
let rounded_qty = self.round_buy_quantity(raw_qty);
targets.insert(symbol.clone(), rounded_qty);
}
Ok(targets)
}
fn process_sell(
&self,
date: NaiveDate,
portfolio: &mut PortfolioState,
data: &DataSet,
symbol: &str,
requested_qty: u32,
reason: &str,
report: &mut BrokerExecutionReport,
) -> Result<(), BacktestError> {
let snapshot = data.require_market(date, symbol)?;
let candidate = data.require_candidate(date, symbol)?;
let Some(position) = portfolio.position(symbol) else {
return Ok(());
};
let rule = self.rules.can_sell(date, snapshot, candidate, position);
if !rule.allowed {
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Sell,
requested_quantity: requested_qty,
filled_quantity: 0,
status: OrderStatus::Rejected,
reason: format!("{reason}: {}", rule.reason.unwrap_or_default()),
});
return Ok(());
}
let sellable = position.sellable_qty(date);
let filled_qty = requested_qty.min(sellable);
if filled_qty == 0 {
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Sell,
requested_quantity: requested_qty,
filled_quantity: 0,
status: OrderStatus::Rejected,
reason: format!("{reason}: no sellable quantity"),
});
return Ok(());
}
let cash_before = portfolio.cash();
let gross_amount = snapshot.open * filled_qty as f64;
let cost = self.cost_model.calculate(OrderSide::Sell, gross_amount);
let net_cash = gross_amount - cost.total();
let realized_pnl = portfolio
.position_mut(symbol)
.sell(filled_qty, snapshot.open)
.map_err(BacktestError::Execution)?;
portfolio.apply_cash_delta(net_cash);
let status = if filled_qty < requested_qty {
OrderStatus::PartiallyFilled
} else {
OrderStatus::Filled
};
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Sell,
requested_quantity: requested_qty,
filled_quantity: filled_qty,
status,
reason: reason.to_string(),
});
report.fill_events.push(FillEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Sell,
quantity: filled_qty,
price: snapshot.open,
gross_amount,
commission: cost.commission,
stamp_tax: cost.stamp_tax,
net_cash_flow: net_cash,
reason: reason.to_string(),
});
report.position_events.push(PositionEvent {
date,
symbol: symbol.to_string(),
delta_quantity: -(filled_qty as i32),
quantity_after: portfolio.position(symbol).map(|pos| pos.quantity).unwrap_or(0),
average_cost: portfolio
.position(symbol)
.map(|pos| pos.average_cost)
.unwrap_or(0.0),
realized_pnl_delta: realized_pnl,
reason: reason.to_string(),
});
report.account_events.push(AccountEvent {
date,
cash_before,
cash_after: portfolio.cash(),
total_equity: self.total_equity_at(date, portfolio, data, PriceField::Open)?,
note: format!("sell {symbol} {reason}"),
});
Ok(())
}
fn process_buy(
&self,
date: NaiveDate,
portfolio: &mut PortfolioState,
data: &DataSet,
symbol: &str,
requested_qty: u32,
reason: &str,
report: &mut BrokerExecutionReport,
) -> Result<(), BacktestError> {
let snapshot = data.require_market(date, symbol)?;
let candidate = data.require_candidate(date, symbol)?;
let rule = self.rules.can_buy(date, snapshot, candidate);
if !rule.allowed {
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Buy,
requested_quantity: requested_qty,
filled_quantity: 0,
status: OrderStatus::Rejected,
reason: format!("{reason}: {}", rule.reason.unwrap_or_default()),
});
return Ok(());
}
let filled_qty =
self.affordable_buy_quantity(portfolio.cash(), snapshot.open, requested_qty);
if filled_qty == 0 {
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Buy,
requested_quantity: requested_qty,
filled_quantity: 0,
status: OrderStatus::Rejected,
reason: format!("{reason}: insufficient cash after fees"),
});
return Ok(());
}
let cash_before = portfolio.cash();
let gross_amount = snapshot.open * filled_qty as f64;
let cost = self.cost_model.calculate(OrderSide::Buy, gross_amount);
let cash_out = gross_amount + cost.total();
portfolio.apply_cash_delta(-cash_out);
portfolio.position_mut(symbol).buy(date, filled_qty, snapshot.open);
let status = if filled_qty < requested_qty {
OrderStatus::PartiallyFilled
} else {
OrderStatus::Filled
};
report.order_events.push(OrderEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Buy,
requested_quantity: requested_qty,
filled_quantity: filled_qty,
status,
reason: reason.to_string(),
});
report.fill_events.push(FillEvent {
date,
symbol: symbol.to_string(),
side: OrderSide::Buy,
quantity: filled_qty,
price: snapshot.open,
gross_amount,
commission: cost.commission,
stamp_tax: cost.stamp_tax,
net_cash_flow: -cash_out,
reason: reason.to_string(),
});
report.position_events.push(PositionEvent {
date,
symbol: symbol.to_string(),
delta_quantity: filled_qty as i32,
quantity_after: portfolio.position(symbol).map(|pos| pos.quantity).unwrap_or(0),
average_cost: portfolio
.position(symbol)
.map(|pos| pos.average_cost)
.unwrap_or(0.0),
realized_pnl_delta: 0.0,
reason: reason.to_string(),
});
report.account_events.push(AccountEvent {
date,
cash_before,
cash_after: portfolio.cash(),
total_equity: self.total_equity_at(date, portfolio, data, PriceField::Open)?,
note: format!("buy {symbol} {reason}"),
});
Ok(())
}
fn total_equity_at(
&self,
date: NaiveDate,
portfolio: &PortfolioState,
data: &DataSet,
field: PriceField,
) -> Result<f64, BacktestError> {
let mut market_value = 0.0;
for position in portfolio.positions().values() {
let price = data
.price(date, &position.symbol, field)
.ok_or_else(|| BacktestError::MissingPrice {
date,
symbol: position.symbol.clone(),
field: match field {
PriceField::Open => "open",
PriceField::Close => "close",
},
})?;
market_value += price * position.quantity as f64;
}
Ok(portfolio.cash() + market_value)
}
fn round_buy_quantity(&self, quantity: u32) -> u32 {
(quantity / self.board_lot_size) * self.board_lot_size
}
fn affordable_buy_quantity(&self, cash: f64, price: f64, requested_qty: u32) -> u32 {
let mut quantity = self.round_buy_quantity(requested_qty);
while quantity > 0 {
let gross = price * quantity as f64;
let cost = self.cost_model.calculate(OrderSide::Buy, gross);
if gross + cost.total() <= cash + 1e-6 {
return quantity;
}
quantity = quantity.saturating_sub(self.board_lot_size);
}
0
}
}
fn sell_reason(decision: &StrategyDecision, symbol: &str) -> &'static str {
if decision.exit_symbols.contains(symbol) {
"exit_hook_sell"
} else {
"rebalance_sell"
}
}

View File

@@ -0,0 +1,58 @@
use std::collections::HashMap;
use chrono::NaiveDate;
#[derive(Debug, Clone)]
pub struct TradingCalendar {
days: Vec<NaiveDate>,
index: HashMap<NaiveDate, usize>,
}
impl TradingCalendar {
pub fn new(mut days: Vec<NaiveDate>) -> Self {
days.sort_unstable();
days.dedup();
let index = days
.iter()
.copied()
.enumerate()
.map(|(idx, day)| (day, idx))
.collect();
Self { days, index }
}
pub fn days(&self) -> &[NaiveDate] {
&self.days
}
pub fn iter(&self) -> impl Iterator<Item = NaiveDate> + '_ {
self.days.iter().copied()
}
pub fn len(&self) -> usize {
self.days.len()
}
pub fn is_empty(&self) -> bool {
self.days.is_empty()
}
pub fn index_of(&self, date: NaiveDate) -> Option<usize> {
self.index.get(&date).copied()
}
pub fn previous_day(&self, date: NaiveDate) -> Option<NaiveDate> {
let idx = self.index_of(date)?;
idx.checked_sub(1).and_then(|prev| self.days.get(prev).copied())
}
pub fn trailing_days(&self, end: NaiveDate, lookback: usize) -> Vec<NaiveDate> {
let Some(end_idx) = self.index_of(end) else {
return Vec::new();
};
let start = end_idx.saturating_add(1).saturating_sub(lookback);
self.days[start..=end_idx].to_vec()
}
}

View File

@@ -0,0 +1,56 @@
use crate::events::OrderSide;
#[derive(Debug, Clone, Copy)]
pub struct TradingCost {
pub commission: f64,
pub stamp_tax: f64,
}
impl TradingCost {
pub fn total(self) -> f64 {
self.commission + self.stamp_tax
}
}
pub trait CostModel {
fn calculate(&self, side: OrderSide, gross_amount: f64) -> TradingCost;
}
#[derive(Debug, Clone, Copy)]
pub struct ChinaAShareCostModel {
pub commission_rate: f64,
pub stamp_tax_rate: f64,
pub minimum_commission: f64,
}
impl Default for ChinaAShareCostModel {
fn default() -> Self {
Self {
commission_rate: 0.0003,
stamp_tax_rate: 0.001,
minimum_commission: 5.0,
}
}
}
impl CostModel for ChinaAShareCostModel {
fn calculate(&self, side: OrderSide, gross_amount: f64) -> TradingCost {
if gross_amount <= 0.0 {
return TradingCost {
commission: 0.0,
stamp_tax: 0.0,
};
}
let commission = (gross_amount * self.commission_rate).max(self.minimum_commission);
let stamp_tax = match side {
OrderSide::Buy => 0.0,
OrderSide::Sell => gross_amount * self.stamp_tax_rate,
};
TradingCost {
commission,
stamp_tax,
}
}
}

View File

@@ -0,0 +1,471 @@
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::path::Path;
use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::calendar::TradingCalendar;
use crate::instrument::Instrument;
mod date_format {
use chrono::NaiveDate;
use serde::{self, Deserialize, Deserializer, Serializer};
const FORMAT: &str = "%Y-%m-%d";
pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&date.format(FORMAT).to_string())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<NaiveDate, D::Error>
where
D: Deserializer<'de>,
{
let text = String::deserialize(deserializer)?;
NaiveDate::parse_from_str(&text, FORMAT).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Error)]
pub enum DataSetError {
#[error("failed to read file {path}: {source}")]
Io {
path: String,
#[source]
source: std::io::Error,
},
#[error("invalid csv row in {path} at line {line}: {message}")]
InvalidRow {
path: String,
line: usize,
message: String,
},
#[error("benchmark file contains multiple benchmark codes")]
MultipleBenchmarks,
#[error("missing data for {kind} on {date} / {symbol}")]
MissingSnapshot {
kind: &'static str,
date: NaiveDate,
symbol: String,
},
#[error("benchmark snapshot missing for {date}")]
MissingBenchmark { date: NaiveDate },
}
#[derive(Debug, Clone, Copy)]
pub enum PriceField {
Open,
Close,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyMarketSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub open: f64,
pub high: f64,
pub low: f64,
pub close: f64,
pub prev_close: f64,
pub volume: u64,
pub paused: bool,
pub upper_limit: f64,
pub lower_limit: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyFactorSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub market_cap_bn: f64,
pub free_float_cap_bn: f64,
pub pe_ttm: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub benchmark: String,
pub open: f64,
pub close: f64,
pub prev_close: f64,
pub volume: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandidateEligibility {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub is_st: bool,
pub is_new_listing: bool,
pub is_paused: bool,
pub allow_buy: bool,
pub allow_sell: bool,
}
impl CandidateEligibility {
pub fn eligible_for_selection(&self) -> bool {
!self.is_st && !self.is_new_listing && !self.is_paused && self.allow_buy && self.allow_sell
}
}
#[derive(Debug, Clone)]
pub struct DataSet {
instruments: HashMap<String, Instrument>,
calendar: TradingCalendar,
market_by_date: BTreeMap<NaiveDate, Vec<DailyMarketSnapshot>>,
market_index: HashMap<(NaiveDate, String), DailyMarketSnapshot>,
factor_by_date: BTreeMap<NaiveDate, Vec<DailyFactorSnapshot>>,
factor_index: HashMap<(NaiveDate, String), DailyFactorSnapshot>,
candidate_by_date: BTreeMap<NaiveDate, Vec<CandidateEligibility>>,
candidate_index: HashMap<(NaiveDate, String), CandidateEligibility>,
benchmark_by_date: BTreeMap<NaiveDate, BenchmarkSnapshot>,
benchmark_code: String,
}
impl DataSet {
pub fn from_csv_dir(path: &Path) -> Result<Self, DataSetError> {
let instruments = read_instruments(&path.join("instruments.csv"))?;
let market = read_market(&path.join("market.csv"))?;
let factors = read_factors(&path.join("factors.csv"))?;
let candidates = read_candidates(&path.join("candidate_flags.csv"))?;
let benchmarks = read_benchmarks(&path.join("benchmark.csv"))?;
let benchmark_code = collect_benchmark_code(&benchmarks)?;
let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect());
let instruments = instruments
.into_iter()
.map(|instrument| (instrument.symbol.clone(), instrument))
.collect::<HashMap<_, _>>();
let market_by_date = group_by_date(market.clone(), |item| item.date);
let market_index = market
.into_iter()
.map(|item| ((item.date, item.symbol.clone()), item))
.collect::<HashMap<_, _>>();
let factor_by_date = group_by_date(factors.clone(), |item| item.date);
let factor_index = factors
.into_iter()
.map(|item| ((item.date, item.symbol.clone()), item))
.collect::<HashMap<_, _>>();
let candidate_by_date = group_by_date(candidates.clone(), |item| item.date);
let candidate_index = candidates
.into_iter()
.map(|item| ((item.date, item.symbol.clone()), item))
.collect::<HashMap<_, _>>();
let benchmark_by_date = benchmarks
.into_iter()
.map(|item| (item.date, item))
.collect::<BTreeMap<_, _>>();
Ok(Self {
instruments,
calendar,
market_by_date,
market_index,
factor_by_date,
factor_index,
candidate_by_date,
candidate_index,
benchmark_by_date,
benchmark_code,
})
}
pub fn calendar(&self) -> &TradingCalendar {
&self.calendar
}
pub fn benchmark_code(&self) -> &str {
&self.benchmark_code
}
pub fn instruments(&self) -> &HashMap<String, Instrument> {
&self.instruments
}
pub fn market(&self, date: NaiveDate, symbol: &str) -> Option<&DailyMarketSnapshot> {
self.market_index.get(&(date, symbol.to_string()))
}
pub fn factor(&self, date: NaiveDate, symbol: &str) -> Option<&DailyFactorSnapshot> {
self.factor_index.get(&(date, symbol.to_string()))
}
pub fn candidate(&self, date: NaiveDate, symbol: &str) -> Option<&CandidateEligibility> {
self.candidate_index.get(&(date, symbol.to_string()))
}
pub fn benchmark(&self, date: NaiveDate) -> Option<&BenchmarkSnapshot> {
self.benchmark_by_date.get(&date)
}
pub fn benchmark_series(&self) -> Vec<BenchmarkSnapshot> {
self.benchmark_by_date.values().cloned().collect()
}
pub fn price(&self, date: NaiveDate, symbol: &str, field: PriceField) -> Option<f64> {
let snapshot = self.market(date, symbol)?;
Some(match field {
PriceField::Open => snapshot.open,
PriceField::Close => snapshot.close,
})
}
pub fn factor_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyFactorSnapshot> {
self.factor_by_date
.get(&date)
.map(|rows| rows.iter().collect())
.unwrap_or_default()
}
pub fn market_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyMarketSnapshot> {
self.market_by_date
.get(&date)
.map(|rows| rows.iter().collect())
.unwrap_or_default()
}
pub fn candidate_snapshots_on(&self, date: NaiveDate) -> Vec<&CandidateEligibility> {
self.candidate_by_date
.get(&date)
.map(|rows| rows.iter().collect())
.unwrap_or_default()
}
pub fn benchmark_closes_up_to(&self, date: NaiveDate, lookback: usize) -> Vec<f64> {
self.calendar
.trailing_days(date, lookback)
.into_iter()
.filter_map(|day| self.benchmark(day).map(|row| row.close))
.collect()
}
pub fn require_market(
&self,
date: NaiveDate,
symbol: &str,
) -> Result<&DailyMarketSnapshot, DataSetError> {
self.market(date, symbol).ok_or_else(|| DataSetError::MissingSnapshot {
kind: "market",
date,
symbol: symbol.to_string(),
})
}
pub fn require_candidate(
&self,
date: NaiveDate,
symbol: &str,
) -> Result<&CandidateEligibility, DataSetError> {
self.candidate(date, symbol)
.ok_or_else(|| DataSetError::MissingSnapshot {
kind: "candidate",
date,
symbol: symbol.to_string(),
})
}
}
fn read_instruments(path: &Path) -> Result<Vec<Instrument>, DataSetError> {
let rows = read_rows(path)?;
let mut instruments = Vec::new();
for row in rows {
instruments.push(Instrument {
symbol: row.get(0)?.to_string(),
name: row.get(1)?.to_string(),
board: row.get(2)?.to_string(),
});
}
Ok(instruments)
}
fn read_market(path: &Path) -> Result<Vec<DailyMarketSnapshot>, DataSetError> {
let rows = read_rows(path)?;
let mut snapshots = Vec::new();
for row in rows {
let prev_close = row.parse_f64(6)?;
snapshots.push(DailyMarketSnapshot {
date: row.parse_date(0)?,
symbol: row.get(1)?.to_string(),
open: row.parse_f64(2)?,
high: row.parse_f64(3)?,
low: row.parse_f64(4)?,
close: row.parse_f64(5)?,
prev_close,
volume: row.parse_u64(7)?,
paused: row.parse_bool(8)?,
upper_limit: round2(prev_close * 1.10),
lower_limit: round2(prev_close * 0.90),
});
}
Ok(snapshots)
}
fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
let rows = read_rows(path)?;
let mut snapshots = Vec::new();
for row in rows {
snapshots.push(DailyFactorSnapshot {
date: row.parse_date(0)?,
symbol: row.get(1)?.to_string(),
market_cap_bn: row.parse_f64(2)?,
free_float_cap_bn: row.parse_f64(3)?,
pe_ttm: row.parse_f64(4)?,
});
}
Ok(snapshots)
}
fn read_candidates(path: &Path) -> Result<Vec<CandidateEligibility>, DataSetError> {
let rows = read_rows(path)?;
let mut snapshots = Vec::new();
for row in rows {
snapshots.push(CandidateEligibility {
date: row.parse_date(0)?,
symbol: row.get(1)?.to_string(),
is_st: row.parse_bool(2)?,
is_new_listing: row.parse_bool(3)?,
is_paused: row.parse_bool(4)?,
allow_buy: row.parse_bool(5)?,
allow_sell: row.parse_bool(6)?,
});
}
Ok(snapshots)
}
fn read_benchmarks(path: &Path) -> Result<Vec<BenchmarkSnapshot>, DataSetError> {
let rows = read_rows(path)?;
let mut snapshots = Vec::new();
for row in rows {
snapshots.push(BenchmarkSnapshot {
date: row.parse_date(0)?,
benchmark: row.get(1)?.to_string(),
open: row.parse_f64(2)?,
close: row.parse_f64(3)?,
prev_close: row.parse_f64(4)?,
volume: row.parse_u64(5)?,
});
}
Ok(snapshots)
}
struct CsvRow {
path: String,
line: usize,
fields: Vec<String>,
}
impl CsvRow {
fn get(&self, index: usize) -> Result<&str, DataSetError> {
self.fields.get(index).map(String::as_str).ok_or_else(|| DataSetError::InvalidRow {
path: self.path.clone(),
line: self.line,
message: format!("missing column {index}"),
})
}
fn parse_date(&self, index: usize) -> Result<NaiveDate, DataSetError> {
NaiveDate::parse_from_str(self.get(index)?, "%Y-%m-%d").map_err(|err| DataSetError::InvalidRow {
path: self.path.clone(),
line: self.line,
message: format!("invalid date: {err}"),
})
}
fn parse_f64(&self, index: usize) -> Result<f64, DataSetError> {
self.get(index)?
.parse::<f64>()
.map_err(|err| DataSetError::InvalidRow {
path: self.path.clone(),
line: self.line,
message: format!("invalid f64: {err}"),
})
}
fn parse_u64(&self, index: usize) -> Result<u64, DataSetError> {
self.get(index)?
.parse::<u64>()
.map_err(|err| DataSetError::InvalidRow {
path: self.path.clone(),
line: self.line,
message: format!("invalid u64: {err}"),
})
}
fn parse_bool(&self, index: usize) -> Result<bool, DataSetError> {
self.get(index)?
.parse::<bool>()
.map_err(|err| DataSetError::InvalidRow {
path: self.path.clone(),
line: self.line,
message: format!("invalid bool: {err}"),
})
}
}
fn read_rows(path: &Path) -> Result<Vec<CsvRow>, DataSetError> {
let content = fs::read_to_string(path).map_err(|source| DataSetError::Io {
path: path.display().to_string(),
source,
})?;
let mut rows = Vec::new();
for (line_idx, line) in content.lines().enumerate() {
let line_no = line_idx + 1;
if line_no == 1 || line.trim().is_empty() {
continue;
}
rows.push(CsvRow {
path: path.display().to_string(),
line: line_no,
fields: line.split(',').map(|field| field.trim().to_string()).collect(),
});
}
Ok(rows)
}
fn group_by_date<T, F>(rows: Vec<T>, mut date_of: F) -> BTreeMap<NaiveDate, Vec<T>>
where
F: FnMut(&T) -> NaiveDate,
{
let mut grouped = BTreeMap::<NaiveDate, Vec<T>>::new();
for row in rows {
grouped.entry(date_of(&row)).or_default().push(row);
}
grouped
}
fn collect_benchmark_code(benchmarks: &[BenchmarkSnapshot]) -> Result<String, DataSetError> {
let mut codes = benchmarks
.iter()
.map(|row| row.benchmark.clone())
.collect::<Vec<_>>();
codes.sort_unstable();
codes.dedup();
if codes.len() == 1 {
Ok(codes.remove(0))
} else {
Err(DataSetError::MultipleBenchmarks)
}
}
fn round2(value: f64) -> f64 {
(value * 100.0).round() / 100.0
}

View File

@@ -0,0 +1,167 @@
use chrono::NaiveDate;
use serde::Serialize;
use thiserror::Error;
use crate::broker::{BrokerExecutionReport, BrokerSimulator};
use crate::cost::CostModel;
use crate::data::{BenchmarkSnapshot, DataSet, DataSetError, PriceField};
use crate::events::{AccountEvent, FillEvent, OrderEvent, PositionEvent};
use crate::portfolio::{HoldingSummary, PortfolioState};
use crate::rules::EquityRuleHooks;
use crate::strategy::{Strategy, StrategyContext, StrategyDecision};
#[derive(Debug, Error)]
pub enum BacktestError {
#[error(transparent)]
Data(#[from] DataSetError),
#[error("missing {field} price for {symbol} on {date}")]
MissingPrice {
date: NaiveDate,
symbol: String,
field: &'static str,
},
#[error("benchmark snapshot missing for {date}")]
MissingBenchmark { date: NaiveDate },
#[error("{0}")]
Execution(String),
}
#[derive(Debug, Clone)]
pub struct BacktestConfig {
pub initial_cash: f64,
pub benchmark_code: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct DailyEquityPoint {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub cash: f64,
pub market_value: f64,
pub total_equity: f64,
pub benchmark_close: f64,
pub notes: String,
}
#[derive(Debug, Clone)]
pub struct BacktestResult {
pub strategy_name: String,
pub equity_curve: Vec<DailyEquityPoint>,
pub benchmark_series: Vec<BenchmarkSnapshot>,
pub order_events: Vec<OrderEvent>,
pub fills: Vec<FillEvent>,
pub position_events: Vec<PositionEvent>,
pub account_events: Vec<AccountEvent>,
pub holdings_summary: Vec<HoldingSummary>,
}
pub struct BacktestEngine<S, C, R> {
data: DataSet,
strategy: S,
broker: BrokerSimulator<C, R>,
config: BacktestConfig,
}
impl<S, C, R> BacktestEngine<S, C, R> {
pub fn new(
data: DataSet,
strategy: S,
broker: BrokerSimulator<C, R>,
config: BacktestConfig,
) -> Self {
Self {
data,
strategy,
broker,
config,
}
}
}
impl<S, C, R> BacktestEngine<S, C, R>
where
S: Strategy,
C: CostModel,
R: EquityRuleHooks,
{
pub fn run(&mut self) -> Result<BacktestResult, BacktestError> {
let mut portfolio = PortfolioState::new(self.config.initial_cash);
let mut result = BacktestResult {
strategy_name: self.strategy.name().to_string(),
benchmark_series: self.data.benchmark_series(),
order_events: Vec::new(),
fills: Vec::new(),
position_events: Vec::new(),
account_events: Vec::new(),
equity_curve: Vec::new(),
holdings_summary: Vec::new(),
};
for execution_date in self.data.calendar().iter() {
let decision = match self.data.calendar().previous_day(execution_date) {
Some(decision_date) => {
let decision_index = self.data.calendar().index_of(decision_date).unwrap_or(0);
self.strategy.on_day(&StrategyContext {
execution_date,
decision_date,
decision_index,
data: &self.data,
portfolio: &portfolio,
})?
}
None => StrategyDecision::default(),
};
let report = self
.broker
.execute(execution_date, &mut portfolio, &self.data, &decision)?;
self.extend_result(&mut result, report);
portfolio.update_prices(execution_date, &self.data, PriceField::Close)?;
let benchmark = self
.data
.benchmark(execution_date)
.ok_or(BacktestError::MissingBenchmark {
date: execution_date,
})?;
let notes = decision.notes.join(" | ");
result.equity_curve.push(DailyEquityPoint {
date: execution_date,
cash: portfolio.cash(),
market_value: portfolio.market_value(),
total_equity: portfolio.total_equity(),
benchmark_close: benchmark.close,
notes,
});
}
if let Some(last_date) = self.data.calendar().days().last().copied() {
result.holdings_summary = portfolio.holdings_summary(last_date);
}
Ok(result)
}
fn extend_result(&self, result: &mut BacktestResult, report: BrokerExecutionReport) {
result.order_events.extend(report.order_events);
result.fills.extend(report.fill_events);
result.position_events.extend(report.position_events);
result.account_events.extend(report.account_events);
}
}
mod date_format {
use chrono::NaiveDate;
use serde::Serializer;
const FORMAT: &str = "%Y-%m-%d";
pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&date.format(FORMAT).to_string())
}
}

View File

@@ -0,0 +1,86 @@
use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
mod date_format {
use chrono::NaiveDate;
use serde::{self, Deserialize, Deserializer, Serializer};
const FORMAT: &str = "%Y-%m-%d";
pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&date.format(FORMAT).to_string())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<NaiveDate, D::Error>
where
D: Deserializer<'de>,
{
let text = String::deserialize(deserializer)?;
NaiveDate::parse_from_str(&text, FORMAT).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum OrderSide {
Buy,
Sell,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum OrderStatus {
Filled,
PartiallyFilled,
Rejected,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderEvent {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub side: OrderSide,
pub requested_quantity: u32,
pub filled_quantity: u32,
pub status: OrderStatus,
pub reason: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FillEvent {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub side: OrderSide,
pub quantity: u32,
pub price: f64,
pub gross_amount: f64,
pub commission: f64,
pub stamp_tax: f64,
pub net_cash_flow: f64,
pub reason: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionEvent {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub delta_quantity: i32,
pub quantity_after: u32,
pub average_cost: f64,
pub realized_pnl_delta: f64,
pub reason: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountEvent {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub cash_before: f64,
pub cash_after: f64,
pub total_equity: f64,
pub note: String,
}

View File

@@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Instrument {
pub symbol: String,
pub name: String,
pub board: String,
}

View File

@@ -0,0 +1,50 @@
pub mod broker;
pub mod calendar;
pub mod cost;
pub mod data;
pub mod engine;
pub mod events;
pub mod instrument;
pub mod portfolio;
pub mod rules;
pub mod strategy;
pub mod universe;
pub use broker::{BrokerExecutionReport, BrokerSimulator};
pub use calendar::TradingCalendar;
pub use cost::{ChinaAShareCostModel, CostModel, TradingCost};
pub use data::{
BenchmarkSnapshot,
CandidateEligibility,
DailyFactorSnapshot,
DailyMarketSnapshot,
DataSet,
DataSetError,
PriceField,
};
pub use engine::{BacktestConfig, BacktestEngine, BacktestError, BacktestResult, DailyEquityPoint};
pub use events::{
AccountEvent,
FillEvent,
OrderEvent,
OrderSide,
OrderStatus,
PositionEvent,
};
pub use instrument::Instrument;
pub use portfolio::{HoldingSummary, PortfolioState, Position};
pub use rules::{ChinaEquityRuleHooks, EquityRuleHooks, RuleCheck};
pub use strategy::{
CnSmallCapRotationConfig,
CnSmallCapRotationStrategy,
Strategy,
StrategyContext,
StrategyDecision,
};
pub use universe::{
BandRegime,
DynamicMarketCapBandSelector,
SelectionContext,
UniverseCandidate,
UniverseSelector,
};

View File

@@ -0,0 +1,242 @@
use std::collections::BTreeMap;
use chrono::NaiveDate;
use serde::Serialize;
use crate::data::{DataSet, DataSetError, PriceField};
#[derive(Debug, Clone)]
pub struct PositionLot {
pub acquired_date: NaiveDate,
pub quantity: u32,
pub price: f64,
}
#[derive(Debug, Clone)]
pub struct Position {
pub symbol: String,
pub quantity: u32,
pub average_cost: f64,
pub last_price: f64,
pub realized_pnl: f64,
lots: Vec<PositionLot>,
}
impl Position {
pub fn new(symbol: impl Into<String>) -> Self {
Self {
symbol: symbol.into(),
quantity: 0,
average_cost: 0.0,
last_price: 0.0,
realized_pnl: 0.0,
lots: Vec::new(),
}
}
pub fn is_flat(&self) -> bool {
self.quantity == 0
}
pub fn buy(&mut self, date: NaiveDate, quantity: u32, price: f64) {
if quantity == 0 {
return;
}
self.lots.push(PositionLot {
acquired_date: date,
quantity,
price,
});
self.quantity += quantity;
self.last_price = price;
self.recalculate_average_cost();
}
pub fn sell(&mut self, quantity: u32, price: f64) -> Result<f64, String> {
if quantity > self.quantity {
return Err(format!(
"sell quantity {} exceeds current quantity {} for {}",
quantity, self.quantity, self.symbol
));
}
let mut remaining = quantity;
let mut realized = 0.0;
while remaining > 0 {
let Some(first_lot) = self.lots.first_mut() else {
return Err(format!("position {} has no lots to sell", self.symbol));
};
let lot_sell = remaining.min(first_lot.quantity);
realized += (price - first_lot.price) * lot_sell as f64;
first_lot.quantity -= lot_sell;
remaining -= lot_sell;
if first_lot.quantity == 0 {
self.lots.remove(0);
}
}
self.quantity -= quantity;
self.last_price = price;
self.realized_pnl += realized;
self.recalculate_average_cost();
Ok(realized)
}
pub fn sellable_qty(&self, date: NaiveDate) -> u32 {
self.lots
.iter()
.filter(|lot| lot.acquired_date < date)
.map(|lot| lot.quantity)
.sum()
}
pub fn market_value(&self) -> f64 {
self.quantity as f64 * self.last_price
}
pub fn unrealized_pnl(&self) -> f64 {
(self.last_price - self.average_cost) * self.quantity as f64
}
pub fn holding_return(&self, price: f64) -> Option<f64> {
if self.quantity == 0 || self.average_cost <= 0.0 {
None
} else {
Some((price / self.average_cost) - 1.0)
}
}
fn recalculate_average_cost(&mut self) {
if self.quantity == 0 {
self.average_cost = 0.0;
return;
}
let total_cost = self
.lots
.iter()
.map(|lot| lot.price * lot.quantity as f64)
.sum::<f64>();
self.average_cost = total_cost / self.quantity as f64;
}
}
#[derive(Debug, Clone)]
pub struct PortfolioState {
cash: f64,
positions: BTreeMap<String, Position>,
}
impl PortfolioState {
pub fn new(initial_cash: f64) -> Self {
Self {
cash: initial_cash,
positions: BTreeMap::new(),
}
}
pub fn cash(&self) -> f64 {
self.cash
}
pub fn positions(&self) -> &BTreeMap<String, Position> {
&self.positions
}
pub fn position(&self, symbol: &str) -> Option<&Position> {
self.positions.get(symbol)
}
pub fn position_mut(&mut self, symbol: &str) -> &mut Position {
self.positions
.entry(symbol.to_string())
.or_insert_with(|| Position::new(symbol))
}
pub fn apply_cash_delta(&mut self, delta: f64) {
self.cash += delta;
}
pub fn prune_flat_positions(&mut self) {
self.positions.retain(|_, position| !position.is_flat());
}
pub fn update_prices(
&mut self,
date: NaiveDate,
data: &DataSet,
field: PriceField,
) -> Result<(), DataSetError> {
for position in self.positions.values_mut() {
let price = data
.price(date, &position.symbol, field)
.ok_or_else(|| DataSetError::MissingSnapshot {
kind: match field {
PriceField::Open => "open price",
PriceField::Close => "close price",
},
date,
symbol: position.symbol.clone(),
})?;
position.last_price = price;
}
Ok(())
}
pub fn market_value(&self) -> f64 {
self.positions.values().map(Position::market_value).sum()
}
pub fn total_equity(&self) -> f64 {
self.cash + self.market_value()
}
pub fn holdings_summary(&self, date: NaiveDate) -> Vec<HoldingSummary> {
self.positions
.values()
.filter(|position| position.quantity > 0)
.map(|position| HoldingSummary {
date,
symbol: position.symbol.clone(),
quantity: position.quantity,
average_cost: position.average_cost,
last_price: position.last_price,
market_value: position.market_value(),
unrealized_pnl: position.unrealized_pnl(),
realized_pnl: position.realized_pnl,
})
.collect()
}
}
#[derive(Debug, Clone, Serialize)]
pub struct HoldingSummary {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub quantity: u32,
pub average_cost: f64,
pub last_price: f64,
pub market_value: f64,
pub unrealized_pnl: f64,
pub realized_pnl: f64,
}
mod date_format {
use chrono::NaiveDate;
use serde::Serializer;
const FORMAT: &str = "%Y-%m-%d";
pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&date.format(FORMAT).to_string())
}
}

View File

@@ -0,0 +1,100 @@
use chrono::NaiveDate;
use crate::data::{CandidateEligibility, DailyMarketSnapshot};
use crate::portfolio::Position;
#[derive(Debug, Clone)]
pub struct RuleCheck {
pub allowed: bool,
pub reason: Option<String>,
}
impl RuleCheck {
pub fn allow() -> Self {
Self {
allowed: true,
reason: None,
}
}
pub fn reject(reason: impl Into<String>) -> Self {
Self {
allowed: false,
reason: Some(reason.into()),
}
}
}
pub trait EquityRuleHooks {
fn can_buy(
&self,
execution_date: NaiveDate,
snapshot: &DailyMarketSnapshot,
candidate: &CandidateEligibility,
) -> RuleCheck;
fn can_sell(
&self,
execution_date: NaiveDate,
snapshot: &DailyMarketSnapshot,
candidate: &CandidateEligibility,
position: &Position,
) -> RuleCheck;
}
#[derive(Debug, Clone, Default)]
pub struct ChinaEquityRuleHooks;
impl ChinaEquityRuleHooks {
fn at_upper_limit(snapshot: &DailyMarketSnapshot) -> bool {
snapshot.open >= snapshot.upper_limit - 1e-6
}
fn at_lower_limit(snapshot: &DailyMarketSnapshot) -> bool {
snapshot.open <= snapshot.lower_limit + 1e-6
}
}
impl EquityRuleHooks for ChinaEquityRuleHooks {
fn can_buy(
&self,
_execution_date: NaiveDate,
snapshot: &DailyMarketSnapshot,
candidate: &CandidateEligibility,
) -> RuleCheck {
if snapshot.paused || candidate.is_paused {
return RuleCheck::reject("paused");
}
if !candidate.allow_buy {
return RuleCheck::reject("buy disabled by eligibility flags");
}
if Self::at_upper_limit(snapshot) {
return RuleCheck::reject("open at or above upper limit");
}
RuleCheck::allow()
}
fn can_sell(
&self,
execution_date: NaiveDate,
snapshot: &DailyMarketSnapshot,
candidate: &CandidateEligibility,
position: &Position,
) -> RuleCheck {
if snapshot.paused || candidate.is_paused {
return RuleCheck::reject("paused");
}
if !candidate.allow_sell {
return RuleCheck::reject("sell disabled by eligibility flags");
}
if Self::at_lower_limit(snapshot) {
return RuleCheck::reject("open at or below lower limit");
}
if position.sellable_qty(execution_date) == 0 {
return RuleCheck::reject("t+1 sellable quantity is zero");
}
RuleCheck::allow()
}
}

View File

@@ -0,0 +1,192 @@
use std::collections::{BTreeMap, BTreeSet};
use chrono::NaiveDate;
use crate::data::{DataSet, PriceField};
use crate::engine::BacktestError;
use crate::portfolio::PortfolioState;
use crate::universe::{DynamicMarketCapBandSelector, SelectionContext, UniverseSelector};
pub trait Strategy {
fn name(&self) -> &'static str;
fn on_day(&mut self, ctx: &StrategyContext<'_>) -> Result<StrategyDecision, BacktestError>;
}
pub struct StrategyContext<'a> {
pub execution_date: NaiveDate,
pub decision_date: NaiveDate,
pub decision_index: usize,
pub data: &'a DataSet,
pub portfolio: &'a PortfolioState,
}
#[derive(Debug, Clone, Default)]
pub struct StrategyDecision {
pub rebalance: bool,
pub target_weights: BTreeMap<String, f64>,
pub exit_symbols: BTreeSet<String>,
pub notes: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct CnSmallCapRotationConfig {
pub rebalance_every_n_days: usize,
pub max_positions: usize,
pub short_ma_days: usize,
pub long_ma_days: usize,
pub stop_loss_pct: f64,
pub take_profit_pct: f64,
}
impl CnSmallCapRotationConfig {
pub fn demo() -> Self {
Self {
rebalance_every_n_days: 3,
max_positions: 2,
short_ma_days: 3,
long_ma_days: 5,
stop_loss_pct: 0.08,
take_profit_pct: 0.10,
}
}
}
pub struct CnSmallCapRotationStrategy {
config: CnSmallCapRotationConfig,
selector: DynamicMarketCapBandSelector,
last_gross_exposure: Option<f64>,
}
impl CnSmallCapRotationStrategy {
pub fn new(config: CnSmallCapRotationConfig) -> Self {
Self {
selector: DynamicMarketCapBandSelector::demo(config.max_positions),
config,
last_gross_exposure: None,
}
}
fn moving_average(values: &[f64], lookback: usize) -> f64 {
let len = values.len();
let window = values.iter().skip(len.saturating_sub(lookback));
let (sum, count) = window.fold((0.0, 0usize), |(sum, count), value| (sum + value, count + 1));
if count == 0 {
0.0
} else {
sum / count as f64
}
}
fn gross_exposure(&self, closes: &[f64]) -> f64 {
if closes.is_empty() {
return 0.0;
}
let current = *closes.last().unwrap_or(&0.0);
let short_ma = Self::moving_average(closes, self.config.short_ma_days);
let long_ma = Self::moving_average(closes, self.config.long_ma_days);
if current >= long_ma && short_ma >= long_ma {
1.0
} else if current >= long_ma || short_ma >= long_ma {
0.5
} else {
0.0
}
}
fn stop_exit_symbols(&self, ctx: &StrategyContext<'_>) -> Result<BTreeSet<String>, BacktestError> {
let mut exits = BTreeSet::new();
for position in ctx.portfolio.positions().values() {
if position.quantity == 0 {
continue;
}
let close_price = ctx
.data
.price(ctx.decision_date, &position.symbol, PriceField::Close)
.ok_or_else(|| BacktestError::MissingPrice {
date: ctx.decision_date,
symbol: position.symbol.clone(),
field: "close",
})?;
let Some(holding_return) = position.holding_return(close_price) else {
continue;
};
if holding_return <= -self.config.stop_loss_pct
|| holding_return >= self.config.take_profit_pct
{
exits.insert(position.symbol.clone());
}
}
Ok(exits)
}
}
impl Strategy for CnSmallCapRotationStrategy {
fn name(&self) -> &'static str {
"cn-smallcap-rotation"
}
fn on_day(&mut self, ctx: &StrategyContext<'_>) -> Result<StrategyDecision, BacktestError> {
let benchmark = ctx
.data
.benchmark(ctx.decision_date)
.ok_or(BacktestError::MissingBenchmark {
date: ctx.decision_date,
})?;
let benchmark_closes = ctx
.data
.benchmark_closes_up_to(ctx.decision_date, self.config.long_ma_days);
let gross_exposure = self.gross_exposure(&benchmark_closes);
let periodic_rebalance = ctx.decision_index % self.config.rebalance_every_n_days == 0;
let exposure_changed = self
.last_gross_exposure
.map(|previous| (previous - gross_exposure).abs() > f64::EPSILON)
.unwrap_or(true);
let exit_symbols = self.stop_exit_symbols(ctx)?;
let rebalance = periodic_rebalance || exposure_changed;
let mut target_weights = BTreeMap::new();
let mut notes = vec![format!(
"decision={} exec={} exposure={:.2}",
ctx.decision_date, ctx.execution_date, gross_exposure
)];
if rebalance && gross_exposure > 0.0 {
let selected = self.selector.select(&SelectionContext {
decision_date: ctx.decision_date,
benchmark,
data: ctx.data,
});
if !selected.is_empty() {
let per_name_weight = gross_exposure / selected.len() as f64;
for candidate in selected {
target_weights.insert(candidate.symbol.clone(), per_name_weight);
}
}
notes.push(format!("rebalance names={}", target_weights.len()));
}
if !exit_symbols.is_empty() {
notes.push(format!("exit hooks={}", exit_symbols.len()));
}
if rebalance && gross_exposure == 0.0 {
notes.push("risk throttle forced all-cash".to_string());
}
self.last_gross_exposure = Some(gross_exposure);
Ok(StrategyDecision {
rebalance,
target_weights,
exit_symbols,
notes,
})
}
}

View File

@@ -0,0 +1,110 @@
use chrono::NaiveDate;
use crate::data::{BenchmarkSnapshot, DataSet};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BandRegime {
Bullish,
Neutral,
Defensive,
}
#[derive(Debug, Clone)]
pub struct UniverseCandidate {
pub symbol: String,
pub market_cap_bn: f64,
pub free_float_cap_bn: f64,
}
pub struct SelectionContext<'a> {
pub decision_date: NaiveDate,
pub benchmark: &'a BenchmarkSnapshot,
pub data: &'a DataSet,
}
pub trait UniverseSelector {
fn select(&self, ctx: &SelectionContext<'_>) -> Vec<UniverseCandidate>;
}
#[derive(Debug, Clone)]
pub struct DynamicMarketCapBandSelector {
pub base_index_level: f64,
pub bullish_threshold: f64,
pub neutral_threshold: f64,
pub bullish_band: (f64, f64),
pub neutral_band: (f64, f64),
pub defensive_band: (f64, f64),
pub top_n: usize,
}
impl DynamicMarketCapBandSelector {
pub fn demo(top_n: usize) -> Self {
Self {
base_index_level: 3000.0,
bullish_threshold: 1.02,
neutral_threshold: 1.0,
bullish_band: (30.0, 60.0),
neutral_band: (40.0, 90.0),
defensive_band: (60.0, 120.0),
top_n,
}
}
pub fn regime(&self, benchmark_level: f64) -> BandRegime {
let ratio = benchmark_level / self.base_index_level;
if ratio >= self.bullish_threshold {
BandRegime::Bullish
} else if ratio >= self.neutral_threshold {
BandRegime::Neutral
} else {
BandRegime::Defensive
}
}
fn band(&self, regime: BandRegime) -> (f64, f64) {
match regime {
BandRegime::Bullish => self.bullish_band,
BandRegime::Neutral => self.neutral_band,
BandRegime::Defensive => self.defensive_band,
}
}
}
impl UniverseSelector for DynamicMarketCapBandSelector {
fn select(&self, ctx: &SelectionContext<'_>) -> Vec<UniverseCandidate> {
let regime = self.regime(ctx.benchmark.close);
let (min_cap, max_cap) = self.band(regime);
let mut selected = ctx
.data
.factor_snapshots_on(ctx.decision_date)
.into_iter()
.filter_map(|factor| {
let candidate = ctx.data.candidate(ctx.decision_date, &factor.symbol)?;
let market = ctx.data.market(ctx.decision_date, &factor.symbol)?;
if !candidate.eligible_for_selection() || market.paused {
return None;
}
if factor.market_cap_bn < min_cap || factor.market_cap_bn > max_cap {
return None;
}
Some(UniverseCandidate {
symbol: factor.symbol.clone(),
market_cap_bn: factor.market_cap_bn,
free_float_cap_bn: factor.free_float_cap_bn,
})
})
.collect::<Vec<_>>();
selected.sort_by(|left, right| {
left.market_cap_bn
.partial_cmp(&right.market_cap_bn)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| left.symbol.cmp(&right.symbol))
});
selected.truncate(self.top_n);
selected
}
}

View File

@@ -0,0 +1,103 @@
use chrono::NaiveDate;
use fidc_core::cost::CostModel;
use fidc_core::rules::EquityRuleHooks;
use fidc_core::{
CandidateEligibility,
ChinaAShareCostModel,
ChinaEquityRuleHooks,
DailyMarketSnapshot,
OrderSide,
Position,
};
fn d(year: i32, month: u32, day: u32) -> NaiveDate {
NaiveDate::from_ymd_opt(year, month, day).expect("valid date")
}
fn candidate() -> CandidateEligibility {
CandidateEligibility {
date: d(2024, 1, 3),
symbol: "000001.SZ".to_string(),
is_st: false,
is_new_listing: false,
is_paused: false,
allow_buy: true,
allow_sell: true,
}
}
fn snapshot(open: f64, upper_limit: f64, lower_limit: f64) -> DailyMarketSnapshot {
DailyMarketSnapshot {
date: d(2024, 1, 3),
symbol: "000001.SZ".to_string(),
open,
high: open,
low: open,
close: open,
prev_close: 10.0,
volume: 1_000_000,
paused: false,
upper_limit,
lower_limit,
}
}
#[test]
fn china_cost_model_applies_minimum_commission_and_stamp_tax() {
let model = ChinaAShareCostModel::default();
let buy = model.calculate(OrderSide::Buy, 1_000.0);
assert!((buy.commission - 5.0).abs() < 1e-9);
assert_eq!(buy.stamp_tax, 0.0);
let sell = model.calculate(OrderSide::Sell, 100_000.0);
assert!((sell.commission - 30.0).abs() < 1e-9);
assert!((sell.stamp_tax - 100.0).abs() < 1e-9);
}
#[test]
fn china_rule_hooks_block_same_day_sell_under_t_plus_one() {
let hooks = ChinaEquityRuleHooks;
let mut position = Position::new("000001.SZ");
let trade_date = d(2024, 1, 3);
position.buy(trade_date, 1_000, 10.0);
let check = hooks.can_sell(
trade_date,
&snapshot(10.1, 11.0, 9.0),
&candidate(),
&position,
);
assert!(!check.allowed);
assert!(check
.reason
.as_deref()
.unwrap_or_default()
.contains("t+1"));
}
#[test]
fn china_rule_hooks_block_buy_at_limit_up_and_sell_at_limit_down() {
let hooks = ChinaEquityRuleHooks;
let candidate = candidate();
let mut position = Position::new("000001.SZ");
position.buy(d(2024, 1, 2), 1_000, 10.0);
let buy_check = hooks.can_buy(d(2024, 1, 3), &snapshot(11.0, 11.0, 9.0), &candidate);
assert!(!buy_check.allowed);
assert!(buy_check
.reason
.as_deref()
.unwrap_or_default()
.contains("upper limit"));
let sell_check =
hooks.can_sell(d(2024, 1, 3), &snapshot(9.0, 11.0, 9.0), &candidate, &position);
assert!(!sell_check.allowed);
assert!(sell_check
.reason
.as_deref()
.unwrap_or_default()
.contains("lower limit"));
}