初始化回测核心引擎骨架

This commit is contained in:
zsb
2026-04-06 23:56:37 -07:00
commit 334864cbc5
25 changed files with 2878 additions and 0 deletions

View File

@@ -0,0 +1,471 @@
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::path::Path;
use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::calendar::TradingCalendar;
use crate::instrument::Instrument;
/// Serde adapter that stores a `NaiveDate` as a `YYYY-MM-DD` string.
///
/// Referenced from the snapshot structs via `#[serde(with = "date_format")]`.
mod date_format {
    use chrono::NaiveDate;
    use serde::{self, Deserialize, Deserializer, Serializer};

    /// `strftime`-style pattern used for both serialization and parsing.
    const FORMAT: &str = "%Y-%m-%d";

    /// Renders `date` with `FORMAT` and emits it as a string.
    pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let rendered = date.format(FORMAT).to_string();
        serializer.serialize_str(&rendered)
    }

    /// Reads a string and parses it with `FORMAT`.
    ///
    /// A parse failure is converted into the deserializer's custom error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<NaiveDate, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = String::deserialize(deserializer)?;
        match NaiveDate::parse_from_str(&raw, FORMAT) {
            Ok(date) => Ok(date),
            Err(err) => Err(serde::de::Error::custom(err)),
        }
    }
}
/// Errors raised while loading the CSV data set or resolving required rows.
#[derive(Debug, Error)]
pub enum DataSetError {
/// Reading the file at `path` failed; `source` carries the underlying I/O error.
#[error("failed to read file {path}: {source}")]
Io {
path: String,
#[source]
source: std::io::Error,
},
/// A CSV data row could not be interpreted; `line` is the 1-based line number
/// within `path` and `message` describes the problem.
#[error("invalid csv row in {path} at line {line}: {message}")]
InvalidRow {
path: String,
line: usize,
message: String,
},
/// The benchmark rows did not resolve to exactly one benchmark code.
/// NOTE(review): `collect_benchmark_code` also returns this variant when there
/// are no benchmark rows at all, so the message can be misleading for an empty file.
#[error("benchmark file contains multiple benchmark codes")]
MultipleBenchmarks,
/// A required per-symbol row is absent; `kind` names the table that was queried
/// (the visible callers pass "market" or "candidate").
#[error("missing data for {kind} on {date} / {symbol}")]
MissingSnapshot {
kind: &'static str,
date: NaiveDate,
symbol: String,
},
/// No benchmark row exists for `date`. Not constructed by the code visible in this file.
#[error("benchmark snapshot missing for {date}")]
MissingBenchmark { date: NaiveDate },
}
/// Selects which price of a daily market snapshot `DataSet::price` returns.
#[derive(Debug, Clone, Copy)]
pub enum PriceField {
/// The snapshot's `open` field.
Open,
/// The snapshot's `close` field.
Close,
}
/// One market-data row for a single symbol on a single date.
///
/// The `date` field is (de)serialized as a `YYYY-MM-DD` string through `date_format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyMarketSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
pub open: f64,
pub high: f64,
pub low: f64,
pub close: f64,
pub prev_close: f64,
// NOTE(review): unit (shares vs. lots) is not visible in this file; confirm.
pub volume: u64,
pub paused: bool,
// When loaded through `read_market`, the two limits are not read from the CSV:
// they are derived as `prev_close * 1.10` and `prev_close * 0.90`, rounded to 2 decimals.
pub upper_limit: f64,
pub lower_limit: f64,
}
/// One factor row for a single symbol on a single date.
///
/// The `date` field is (de)serialized as a `YYYY-MM-DD` string through `date_format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DailyFactorSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
// NOTE(review): the `_bn` suffix presumably means "billions"; currency unit not visible here.
pub market_cap_bn: f64,
pub free_float_cap_bn: f64,
// NOTE(review): presumably trailing-twelve-month P/E; confirm against the data producer.
pub pe_ttm: f64,
}
/// One benchmark row for a single date.
///
/// The `date` field is (de)serialized as a `YYYY-MM-DD` string through `date_format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSnapshot {
#[serde(with = "date_format")]
pub date: NaiveDate,
/// Benchmark code; `DataSet::from_csv_dir` requires all rows to share one code.
pub benchmark: String,
pub open: f64,
pub close: f64,
pub prev_close: f64,
pub volume: u64,
}
/// Per-symbol, per-date flags consulted by `eligible_for_selection`.
///
/// The `date` field is (de)serialized as a `YYYY-MM-DD` string through `date_format`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandidateEligibility {
#[serde(with = "date_format")]
pub date: NaiveDate,
pub symbol: String,
// NOTE(review): presumably the exchange "special treatment" (ST) marker; confirm.
pub is_st: bool,
pub is_new_listing: bool,
pub is_paused: bool,
pub allow_buy: bool,
pub allow_sell: bool,
}
impl CandidateEligibility {
    /// Returns `true` only when none of `is_st`, `is_new_listing`, `is_paused`
    /// is set and both `allow_buy` and `allow_sell` are set.
    pub fn eligible_for_selection(&self) -> bool {
        let excluded = self.is_st || self.is_new_listing || self.is_paused;
        let tradable_both_ways = self.allow_buy && self.allow_sell;
        !excluded && tradable_both_ways
    }
}
/// In-memory view of the CSV data set used by the backtest.
///
/// Rows are kept twice: grouped by date (for "everything on day X" queries)
/// and in a nested `date -> symbol -> row` index (for point lookups). The
/// nested index lets lookups borrow the caller's `&str` symbol instead of
/// allocating a `String` key on every call.
#[derive(Debug, Clone)]
pub struct DataSet {
    instruments: HashMap<String, Instrument>,
    calendar: TradingCalendar,
    market_by_date: BTreeMap<NaiveDate, Vec<DailyMarketSnapshot>>,
    market_index: HashMap<NaiveDate, HashMap<String, DailyMarketSnapshot>>,
    factor_by_date: BTreeMap<NaiveDate, Vec<DailyFactorSnapshot>>,
    factor_index: HashMap<NaiveDate, HashMap<String, DailyFactorSnapshot>>,
    candidate_by_date: BTreeMap<NaiveDate, Vec<CandidateEligibility>>,
    candidate_index: HashMap<NaiveDate, HashMap<String, CandidateEligibility>>,
    benchmark_by_date: BTreeMap<NaiveDate, BenchmarkSnapshot>,
    benchmark_code: String,
}

impl DataSet {
    /// Loads a data set from the CSV files inside `path`.
    ///
    /// Reads `instruments.csv`, `market.csv`, `factors.csv`,
    /// `candidate_flags.csv` and `benchmark.csv`, in that order. The trading
    /// calendar is built from the dates of the benchmark rows.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a reader, or the error from
    /// `collect_benchmark_code` when the benchmark rows do not share exactly
    /// one code.
    pub fn from_csv_dir(path: &Path) -> Result<Self, DataSetError> {
        let instruments = read_instruments(&path.join("instruments.csv"))?;
        let market = read_market(&path.join("market.csv"))?;
        let factors = read_factors(&path.join("factors.csv"))?;
        let candidates = read_candidates(&path.join("candidate_flags.csv"))?;
        let benchmarks = read_benchmarks(&path.join("benchmark.csv"))?;

        let benchmark_code = collect_benchmark_code(&benchmarks)?;
        let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect());

        let instruments = instruments
            .into_iter()
            .map(|instrument| (instrument.symbol.clone(), instrument))
            .collect::<HashMap<_, _>>();

        let market_by_date = group_by_date(market.clone(), |item| item.date);
        let market_index =
            Self::index_by_date_and_symbol(market, |item| (item.date, item.symbol.clone()));

        let factor_by_date = group_by_date(factors.clone(), |item| item.date);
        let factor_index =
            Self::index_by_date_and_symbol(factors, |item| (item.date, item.symbol.clone()));

        let candidate_by_date = group_by_date(candidates.clone(), |item| item.date);
        let candidate_index =
            Self::index_by_date_and_symbol(candidates, |item| (item.date, item.symbol.clone()));

        let benchmark_by_date = benchmarks
            .into_iter()
            .map(|item| (item.date, item))
            .collect::<BTreeMap<_, _>>();

        Ok(Self {
            instruments,
            calendar,
            market_by_date,
            market_index,
            factor_by_date,
            factor_index,
            candidate_by_date,
            candidate_index,
            benchmark_by_date,
            benchmark_code,
        })
    }

    /// Builds the nested `date -> symbol -> row` lookup table.
    ///
    /// When several rows share the same date and symbol, the last one in
    /// `rows` wins, because each insert overwrites the previous entry.
    fn index_by_date_and_symbol<T>(
        rows: Vec<T>,
        key_of: impl Fn(&T) -> (NaiveDate, String),
    ) -> HashMap<NaiveDate, HashMap<String, T>> {
        let mut index: HashMap<NaiveDate, HashMap<String, T>> = HashMap::new();
        for row in rows {
            let (date, symbol) = key_of(&row);
            index.entry(date).or_default().insert(symbol, row);
        }
        index
    }

    /// Trading calendar derived from the benchmark dates.
    pub fn calendar(&self) -> &TradingCalendar {
        &self.calendar
    }

    /// The single benchmark code found in the benchmark file.
    pub fn benchmark_code(&self) -> &str {
        &self.benchmark_code
    }

    /// All instruments, keyed by symbol.
    pub fn instruments(&self) -> &HashMap<String, Instrument> {
        &self.instruments
    }

    /// Market row for `symbol` on `date`, if one was loaded.
    pub fn market(&self, date: NaiveDate, symbol: &str) -> Option<&DailyMarketSnapshot> {
        self.market_index.get(&date)?.get(symbol)
    }

    /// Factor row for `symbol` on `date`, if one was loaded.
    pub fn factor(&self, date: NaiveDate, symbol: &str) -> Option<&DailyFactorSnapshot> {
        self.factor_index.get(&date)?.get(symbol)
    }

    /// Eligibility row for `symbol` on `date`, if one was loaded.
    pub fn candidate(&self, date: NaiveDate, symbol: &str) -> Option<&CandidateEligibility> {
        self.candidate_index.get(&date)?.get(symbol)
    }

    /// Benchmark row for `date`, if one was loaded.
    pub fn benchmark(&self, date: NaiveDate) -> Option<&BenchmarkSnapshot> {
        self.benchmark_by_date.get(&date)
    }

    /// Owned copy of every benchmark row, in ascending date order.
    pub fn benchmark_series(&self) -> Vec<BenchmarkSnapshot> {
        self.benchmark_by_date.values().cloned().collect()
    }

    /// Open or close price of `symbol` on `date`; `None` when no market row exists.
    pub fn price(&self, date: NaiveDate, symbol: &str, field: PriceField) -> Option<f64> {
        let snapshot = self.market(date, symbol)?;
        Some(match field {
            PriceField::Open => snapshot.open,
            PriceField::Close => snapshot.close,
        })
    }

    /// All factor rows for `date` in file order; empty when the date has no rows.
    pub fn factor_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyFactorSnapshot> {
        self.factor_by_date
            .get(&date)
            .map(|rows| rows.iter().collect())
            .unwrap_or_default()
    }

    /// All market rows for `date` in file order; empty when the date has no rows.
    pub fn market_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyMarketSnapshot> {
        self.market_by_date
            .get(&date)
            .map(|rows| rows.iter().collect())
            .unwrap_or_default()
    }

    /// All eligibility rows for `date` in file order; empty when the date has no rows.
    pub fn candidate_snapshots_on(&self, date: NaiveDate) -> Vec<&CandidateEligibility> {
        self.candidate_by_date
            .get(&date)
            .map(|rows| rows.iter().collect())
            .unwrap_or_default()
    }

    /// Benchmark closes for the days returned by
    /// `calendar.trailing_days(date, lookback)`.
    ///
    /// Days without a benchmark row are skipped, so the result may be shorter
    /// than the calendar window.
    pub fn benchmark_closes_up_to(&self, date: NaiveDate, lookback: usize) -> Vec<f64> {
        self.calendar
            .trailing_days(date, lookback)
            .into_iter()
            .filter_map(|day| self.benchmark(day).map(|row| row.close))
            .collect()
    }

    /// Like [`DataSet::market`], but a missing row is an error.
    ///
    /// # Errors
    ///
    /// Returns `DataSetError::MissingSnapshot` with `kind` set to `"market"`.
    pub fn require_market(
        &self,
        date: NaiveDate,
        symbol: &str,
    ) -> Result<&DailyMarketSnapshot, DataSetError> {
        self.market(date, symbol)
            .ok_or_else(|| DataSetError::MissingSnapshot {
                kind: "market",
                date,
                symbol: symbol.to_string(),
            })
    }

    /// Like [`DataSet::candidate`], but a missing row is an error.
    ///
    /// # Errors
    ///
    /// Returns `DataSetError::MissingSnapshot` with `kind` set to `"candidate"`.
    pub fn require_candidate(
        &self,
        date: NaiveDate,
        symbol: &str,
    ) -> Result<&CandidateEligibility, DataSetError> {
        self.candidate(date, symbol)
            .ok_or_else(|| DataSetError::MissingSnapshot {
                kind: "candidate",
                date,
                symbol: symbol.to_string(),
            })
    }
}
/// Loads instrument definitions from the CSV file at `path`.
///
/// Columns are read by position: 0 = symbol, 1 = name, 2 = board.
/// Fails with the first error from `read_rows` or from a missing column.
fn read_instruments(path: &Path) -> Result<Vec<Instrument>, DataSetError> {
    read_rows(path)?
        .into_iter()
        .map(|record| -> Result<Instrument, DataSetError> {
            Ok(Instrument {
                symbol: record.get(0)?.to_string(),
                name: record.get(1)?.to_string(),
                board: record.get(2)?.to_string(),
            })
        })
        .collect()
}
/// Loads daily market rows from the CSV file at `path`.
///
/// Columns are read by position: 0 = date, 1 = symbol, 2 = open, 3 = high,
/// 4 = low, 5 = close, 6 = prev_close, 7 = volume, 8 = paused.
/// `upper_limit` / `lower_limit` are not read from the file; they are derived
/// from `prev_close` using fixed +10% / -10% factors and rounded with `round2`.
///
/// NOTE(review): the 10% band is applied to every symbol; confirm that no
/// instrument in the data set trades under a different daily price limit.
fn read_market(path: &Path) -> Result<Vec<DailyMarketSnapshot>, DataSetError> {
    const UPPER_LIMIT_FACTOR: f64 = 1.10;
    const LOWER_LIMIT_FACTOR: f64 = 0.90;

    read_rows(path)?
        .into_iter()
        .map(|record| -> Result<DailyMarketSnapshot, DataSetError> {
            // Parsed ahead of the other columns because both limits depend on it;
            // a bad column 6 is therefore reported before any other column of the row.
            let prev_close = record.parse_f64(6)?;
            Ok(DailyMarketSnapshot {
                date: record.parse_date(0)?,
                symbol: record.get(1)?.to_string(),
                open: record.parse_f64(2)?,
                high: record.parse_f64(3)?,
                low: record.parse_f64(4)?,
                close: record.parse_f64(5)?,
                prev_close,
                volume: record.parse_u64(7)?,
                paused: record.parse_bool(8)?,
                upper_limit: round2(prev_close * UPPER_LIMIT_FACTOR),
                lower_limit: round2(prev_close * LOWER_LIMIT_FACTOR),
            })
        })
        .collect()
}
/// Loads daily factor rows from the CSV file at `path`.
///
/// Columns are read by position: 0 = date, 1 = symbol, 2 = market_cap_bn,
/// 3 = free_float_cap_bn, 4 = pe_ttm. Stops at the first row that fails to parse.
fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
    read_rows(path)?
        .into_iter()
        .map(|record| -> Result<DailyFactorSnapshot, DataSetError> {
            Ok(DailyFactorSnapshot {
                date: record.parse_date(0)?,
                symbol: record.get(1)?.to_string(),
                market_cap_bn: record.parse_f64(2)?,
                free_float_cap_bn: record.parse_f64(3)?,
                pe_ttm: record.parse_f64(4)?,
            })
        })
        .collect()
}
/// Loads eligibility flag rows from the CSV file at `path`.
///
/// Columns are read by position: 0 = date, 1 = symbol, 2 = is_st,
/// 3 = is_new_listing, 4 = is_paused, 5 = allow_buy, 6 = allow_sell.
/// Stops at the first row that fails to parse.
fn read_candidates(path: &Path) -> Result<Vec<CandidateEligibility>, DataSetError> {
    read_rows(path)?
        .into_iter()
        .map(|record| -> Result<CandidateEligibility, DataSetError> {
            Ok(CandidateEligibility {
                date: record.parse_date(0)?,
                symbol: record.get(1)?.to_string(),
                is_st: record.parse_bool(2)?,
                is_new_listing: record.parse_bool(3)?,
                is_paused: record.parse_bool(4)?,
                allow_buy: record.parse_bool(5)?,
                allow_sell: record.parse_bool(6)?,
            })
        })
        .collect()
}
/// Loads benchmark rows from the CSV file at `path`.
///
/// Columns are read by position: 0 = date, 1 = benchmark code, 2 = open,
/// 3 = close, 4 = prev_close, 5 = volume. Stops at the first row that fails to parse.
fn read_benchmarks(path: &Path) -> Result<Vec<BenchmarkSnapshot>, DataSetError> {
    read_rows(path)?
        .into_iter()
        .map(|record| -> Result<BenchmarkSnapshot, DataSetError> {
            Ok(BenchmarkSnapshot {
                date: record.parse_date(0)?,
                benchmark: record.get(1)?.to_string(),
                open: record.parse_f64(2)?,
                close: record.parse_f64(3)?,
                prev_close: record.parse_f64(4)?,
                volume: record.parse_u64(5)?,
            })
        })
        .collect()
}
/// One data line of a CSV file, already split into fields by `read_rows`.
struct CsvRow {
/// Display form of the source file path, kept so errors can name the file.
path: String,
/// 1-based line number of this row within the file.
line: usize,
/// The line split on `,`, with each field trimmed of surrounding whitespace.
fields: Vec<String>,
}
impl CsvRow {
    /// Builds an `InvalidRow` error that points at this row's file and line.
    fn invalid(&self, message: String) -> DataSetError {
        DataSetError::InvalidRow {
            path: self.path.clone(),
            line: self.line,
            message,
        }
    }

    /// Returns the field at the zero-based column `index`.
    ///
    /// # Errors
    ///
    /// `InvalidRow` with message `missing column {index}` when the row is too short.
    fn get(&self, index: usize) -> Result<&str, DataSetError> {
        self.fields
            .get(index)
            .map(String::as_str)
            .ok_or_else(|| self.invalid(format!("missing column {index}")))
    }

    /// Parses column `index` as a `YYYY-MM-DD` date.
    ///
    /// # Errors
    ///
    /// `InvalidRow` naming the column and the offending text.
    fn parse_date(&self, index: usize) -> Result<NaiveDate, DataSetError> {
        let raw = self.get(index)?;
        NaiveDate::parse_from_str(raw, "%Y-%m-%d")
            .map_err(|err| self.invalid(format!("invalid date in column {index} ({raw:?}): {err}")))
    }

    /// Parses column `index` as an `f64`.
    ///
    /// # Errors
    ///
    /// `InvalidRow` naming the column and the offending text.
    fn parse_f64(&self, index: usize) -> Result<f64, DataSetError> {
        let raw = self.get(index)?;
        raw.parse::<f64>()
            .map_err(|err| self.invalid(format!("invalid f64 in column {index} ({raw:?}): {err}")))
    }

    /// Parses column `index` as a `u64`.
    ///
    /// # Errors
    ///
    /// `InvalidRow` naming the column and the offending text.
    fn parse_u64(&self, index: usize) -> Result<u64, DataSetError> {
        let raw = self.get(index)?;
        raw.parse::<u64>()
            .map_err(|err| self.invalid(format!("invalid u64 in column {index} ({raw:?}): {err}")))
    }

    /// Parses column `index` as a boolean.
    ///
    /// Accepts `true` / `false` in any ASCII letter case, plus `1` / `0`, so
    /// files written by tools that emit `True`/`False` or numeric flags load
    /// without preprocessing. The lowercase spellings accepted before are
    /// still accepted.
    ///
    /// # Errors
    ///
    /// `InvalidRow` naming the column and the offending text.
    fn parse_bool(&self, index: usize) -> Result<bool, DataSetError> {
        let raw = self.get(index)?;
        if raw == "1" || raw.eq_ignore_ascii_case("true") {
            Ok(true)
        } else if raw == "0" || raw.eq_ignore_ascii_case("false") {
            Ok(false)
        } else {
            Err(self.invalid(format!(
                "invalid bool in column {index} ({raw:?}): expected true/false or 1/0"
            )))
        }
    }
}
/// Reads the file at `path` and splits its data lines into `CsvRow`s.
///
/// The first line is always skipped (treated as the header) and lines that are
/// blank after trimming are ignored. Every other line is split on `,` and each
/// field is trimmed. Line numbers stored in the rows are 1-based positions in
/// the file, so skipped lines still count.
///
/// NOTE(review): fields are split on every comma; quoted fields containing
/// commas are not supported by this reader.
///
/// # Errors
///
/// `DataSetError::Io` when the file cannot be read as UTF-8 text.
fn read_rows(path: &Path) -> Result<Vec<CsvRow>, DataSetError> {
    let source_path = path.display().to_string();
    let content = match fs::read_to_string(path) {
        Ok(text) => text,
        Err(source) => {
            return Err(DataSetError::Io {
                path: source_path,
                source,
            })
        }
    };

    let rows: Vec<CsvRow> = content
        .lines()
        .enumerate()
        .skip(1)
        .filter(|(_, text)| !text.trim().is_empty())
        .map(|(position, text)| CsvRow {
            path: source_path.clone(),
            line: position + 1,
            fields: text.split(',').map(|cell| cell.trim().to_string()).collect(),
        })
        .collect();
    Ok(rows)
}
/// Buckets `rows` by the date that `date_of` extracts from each row.
///
/// Dates come out in ascending order (the map is a `BTreeMap`); within one
/// date the rows keep their input order.
fn group_by_date<T, F>(rows: Vec<T>, mut date_of: F) -> BTreeMap<NaiveDate, Vec<T>>
where
    F: FnMut(&T) -> NaiveDate,
{
    rows.into_iter()
        .fold(BTreeMap::<NaiveDate, Vec<T>>::new(), |mut buckets, row| {
            let key = date_of(&row);
            buckets.entry(key).or_default().push(row);
            buckets
        })
}
/// Returns the single benchmark code shared by all `benchmarks` rows.
///
/// The check is one borrowed pass over the rows: every code is compared with
/// the first one, so no per-row `String` clone or sort is needed.
///
/// # Errors
///
/// `DataSetError::MultipleBenchmarks` when the rows carry more than one
/// distinct code. The same variant is returned for an empty slice, because no
/// dedicated "no benchmark rows" error exists.
fn collect_benchmark_code(benchmarks: &[BenchmarkSnapshot]) -> Result<String, DataSetError> {
    let mut codes = benchmarks.iter().map(|row| row.benchmark.as_str());
    let first = codes.next().ok_or(DataSetError::MultipleBenchmarks)?;
    if codes.all(|code| code == first) {
        Ok(first.to_string())
    } else {
        Err(DataSetError::MultipleBenchmarks)
    }
}
/// Rounds `value` to two decimal places: scale by 100, apply `f64::round`
/// (ties away from zero), then scale back.
fn round2(value: f64) -> f64 {
    const SCALE: f64 = 100.0;
    let scaled = value * SCALE;
    scaled.round() / SCALE
}