use std::collections::{BTreeMap, HashMap}; use std::fs; use std::path::Path; use chrono::NaiveDate; use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::calendar::TradingCalendar; use crate::instrument::Instrument; mod date_format { use chrono::NaiveDate; use serde::{self, Deserialize, Deserializer, Serializer}; const FORMAT: &str = "%Y-%m-%d"; pub fn serialize(date: &NaiveDate, serializer: S) -> Result where S: Serializer, { serializer.serialize_str(&date.format(FORMAT).to_string()) } pub fn deserialize<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { let text = String::deserialize(deserializer)?; NaiveDate::parse_from_str(&text, FORMAT).map_err(serde::de::Error::custom) } } #[derive(Debug, Error)] pub enum DataSetError { #[error("failed to read file {path}: {source}")] Io { path: String, #[source] source: std::io::Error, }, #[error("invalid csv row in {path} at line {line}: {message}")] InvalidRow { path: String, line: usize, message: String, }, #[error("benchmark file contains multiple benchmark codes")] MultipleBenchmarks, #[error("missing data for {kind} on {date} / {symbol}")] MissingSnapshot { kind: &'static str, date: NaiveDate, symbol: String, }, #[error("benchmark snapshot missing for {date}")] MissingBenchmark { date: NaiveDate }, } #[derive(Debug, Clone, Copy)] pub enum PriceField { Open, Close, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DailyMarketSnapshot { #[serde(with = "date_format")] pub date: NaiveDate, pub symbol: String, pub open: f64, pub high: f64, pub low: f64, pub close: f64, pub prev_close: f64, pub volume: u64, pub paused: bool, pub upper_limit: f64, pub lower_limit: f64, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DailyFactorSnapshot { #[serde(with = "date_format")] pub date: NaiveDate, pub symbol: String, pub market_cap_bn: f64, pub free_float_cap_bn: f64, pub pe_ttm: f64, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BenchmarkSnapshot { #[serde(with = "date_format")] pub date: NaiveDate, pub benchmark: String, pub open: f64, pub close: f64, pub prev_close: f64, pub volume: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CandidateEligibility { #[serde(with = "date_format")] pub date: NaiveDate, pub symbol: String, pub is_st: bool, pub is_new_listing: bool, pub is_paused: bool, pub allow_buy: bool, pub allow_sell: bool, } impl CandidateEligibility { pub fn eligible_for_selection(&self) -> bool { !self.is_st && !self.is_new_listing && !self.is_paused && self.allow_buy && self.allow_sell } } #[derive(Debug, Clone)] pub struct DataSet { instruments: HashMap, calendar: TradingCalendar, market_by_date: BTreeMap>, market_index: HashMap<(NaiveDate, String), DailyMarketSnapshot>, factor_by_date: BTreeMap>, factor_index: HashMap<(NaiveDate, String), DailyFactorSnapshot>, candidate_by_date: BTreeMap>, candidate_index: HashMap<(NaiveDate, String), CandidateEligibility>, benchmark_by_date: BTreeMap, benchmark_code: String, } impl DataSet { pub fn from_csv_dir(path: &Path) -> Result { let instruments = read_instruments(&path.join("instruments.csv"))?; let market = read_market(&path.join("market.csv"))?; let factors = read_factors(&path.join("factors.csv"))?; let candidates = read_candidates(&path.join("candidate_flags.csv"))?; let benchmarks = read_benchmarks(&path.join("benchmark.csv"))?; let benchmark_code = collect_benchmark_code(&benchmarks)?; let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect()); let instruments = instruments .into_iter() .map(|instrument| (instrument.symbol.clone(), instrument)) .collect::>(); let market_by_date = group_by_date(market.clone(), |item| item.date); let market_index = market .into_iter() .map(|item| ((item.date, item.symbol.clone()), item)) .collect::>(); let factor_by_date = group_by_date(factors.clone(), |item| item.date); let factor_index = factors .into_iter() .map(|item| ((item.date, item.symbol.clone()), item)) .collect::>(); let candidate_by_date = group_by_date(candidates.clone(), |item| item.date); let candidate_index = candidates .into_iter() .map(|item| ((item.date, item.symbol.clone()), item)) .collect::>(); let benchmark_by_date = benchmarks .into_iter() .map(|item| (item.date, item)) .collect::>(); Ok(Self { instruments, calendar, market_by_date, market_index, factor_by_date, factor_index, candidate_by_date, candidate_index, benchmark_by_date, benchmark_code, }) } pub fn calendar(&self) -> &TradingCalendar { &self.calendar } pub fn benchmark_code(&self) -> &str { &self.benchmark_code } pub fn instruments(&self) -> &HashMap { &self.instruments } pub fn market(&self, date: NaiveDate, symbol: &str) -> Option<&DailyMarketSnapshot> { self.market_index.get(&(date, symbol.to_string())) } pub fn factor(&self, date: NaiveDate, symbol: &str) -> Option<&DailyFactorSnapshot> { self.factor_index.get(&(date, symbol.to_string())) } pub fn candidate(&self, date: NaiveDate, symbol: &str) -> Option<&CandidateEligibility> { self.candidate_index.get(&(date, symbol.to_string())) } pub fn benchmark(&self, date: NaiveDate) -> Option<&BenchmarkSnapshot> { self.benchmark_by_date.get(&date) } pub fn benchmark_series(&self) -> Vec { self.benchmark_by_date.values().cloned().collect() } pub fn price(&self, date: NaiveDate, symbol: &str, field: PriceField) -> Option { let snapshot = self.market(date, symbol)?; Some(match field { PriceField::Open => snapshot.open, PriceField::Close => snapshot.close, }) } pub fn factor_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyFactorSnapshot> { self.factor_by_date .get(&date) .map(|rows| rows.iter().collect()) .unwrap_or_default() } pub fn market_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyMarketSnapshot> { self.market_by_date .get(&date) .map(|rows| rows.iter().collect()) .unwrap_or_default() } pub fn candidate_snapshots_on(&self, date: NaiveDate) -> Vec<&CandidateEligibility> { self.candidate_by_date .get(&date) .map(|rows| rows.iter().collect()) .unwrap_or_default() } pub fn benchmark_closes_up_to(&self, date: NaiveDate, lookback: usize) -> Vec { self.calendar .trailing_days(date, lookback) .into_iter() .filter_map(|day| self.benchmark(day).map(|row| row.close)) .collect() } pub fn require_market( &self, date: NaiveDate, symbol: &str, ) -> Result<&DailyMarketSnapshot, DataSetError> { self.market(date, symbol).ok_or_else(|| DataSetError::MissingSnapshot { kind: "market", date, symbol: symbol.to_string(), }) } pub fn require_candidate( &self, date: NaiveDate, symbol: &str, ) -> Result<&CandidateEligibility, DataSetError> { self.candidate(date, symbol) .ok_or_else(|| DataSetError::MissingSnapshot { kind: "candidate", date, symbol: symbol.to_string(), }) } } fn read_instruments(path: &Path) -> Result, DataSetError> { let rows = read_rows(path)?; let mut instruments = Vec::new(); for row in rows { instruments.push(Instrument { symbol: row.get(0)?.to_string(), name: row.get(1)?.to_string(), board: row.get(2)?.to_string(), }); } Ok(instruments) } fn read_market(path: &Path) -> Result, DataSetError> { let rows = read_rows(path)?; let mut snapshots = Vec::new(); for row in rows { let prev_close = row.parse_f64(6)?; snapshots.push(DailyMarketSnapshot { date: row.parse_date(0)?, symbol: row.get(1)?.to_string(), open: row.parse_f64(2)?, high: row.parse_f64(3)?, low: row.parse_f64(4)?, close: row.parse_f64(5)?, prev_close, volume: row.parse_u64(7)?, paused: row.parse_bool(8)?, upper_limit: round2(prev_close * 1.10), lower_limit: round2(prev_close * 0.90), }); } Ok(snapshots) } fn read_factors(path: &Path) -> Result, DataSetError> { let rows = read_rows(path)?; let mut snapshots = Vec::new(); for row in rows { snapshots.push(DailyFactorSnapshot { date: row.parse_date(0)?, symbol: row.get(1)?.to_string(), market_cap_bn: row.parse_f64(2)?, free_float_cap_bn: row.parse_f64(3)?, pe_ttm: row.parse_f64(4)?, }); } Ok(snapshots) } fn read_candidates(path: &Path) -> Result, DataSetError> { let rows = read_rows(path)?; let mut snapshots = Vec::new(); for row in rows { snapshots.push(CandidateEligibility { date: row.parse_date(0)?, symbol: row.get(1)?.to_string(), is_st: row.parse_bool(2)?, is_new_listing: row.parse_bool(3)?, is_paused: row.parse_bool(4)?, allow_buy: row.parse_bool(5)?, allow_sell: row.parse_bool(6)?, }); } Ok(snapshots) } fn read_benchmarks(path: &Path) -> Result, DataSetError> { let rows = read_rows(path)?; let mut snapshots = Vec::new(); for row in rows { snapshots.push(BenchmarkSnapshot { date: row.parse_date(0)?, benchmark: row.get(1)?.to_string(), open: row.parse_f64(2)?, close: row.parse_f64(3)?, prev_close: row.parse_f64(4)?, volume: row.parse_u64(5)?, }); } Ok(snapshots) } struct CsvRow { path: String, line: usize, fields: Vec, } impl CsvRow { fn get(&self, index: usize) -> Result<&str, DataSetError> { self.fields.get(index).map(String::as_str).ok_or_else(|| DataSetError::InvalidRow { path: self.path.clone(), line: self.line, message: format!("missing column {index}"), }) } fn parse_date(&self, index: usize) -> Result { NaiveDate::parse_from_str(self.get(index)?, "%Y-%m-%d").map_err(|err| DataSetError::InvalidRow { path: self.path.clone(), line: self.line, message: format!("invalid date: {err}"), }) } fn parse_f64(&self, index: usize) -> Result { self.get(index)? .parse::() .map_err(|err| DataSetError::InvalidRow { path: self.path.clone(), line: self.line, message: format!("invalid f64: {err}"), }) } fn parse_u64(&self, index: usize) -> Result { self.get(index)? .parse::() .map_err(|err| DataSetError::InvalidRow { path: self.path.clone(), line: self.line, message: format!("invalid u64: {err}"), }) } fn parse_bool(&self, index: usize) -> Result { self.get(index)? .parse::() .map_err(|err| DataSetError::InvalidRow { path: self.path.clone(), line: self.line, message: format!("invalid bool: {err}"), }) } } fn read_rows(path: &Path) -> Result, DataSetError> { let content = fs::read_to_string(path).map_err(|source| DataSetError::Io { path: path.display().to_string(), source, })?; let mut rows = Vec::new(); for (line_idx, line) in content.lines().enumerate() { let line_no = line_idx + 1; if line_no == 1 || line.trim().is_empty() { continue; } rows.push(CsvRow { path: path.display().to_string(), line: line_no, fields: line.split(',').map(|field| field.trim().to_string()).collect(), }); } Ok(rows) } fn group_by_date(rows: Vec, mut date_of: F) -> BTreeMap> where F: FnMut(&T) -> NaiveDate, { let mut grouped = BTreeMap::>::new(); for row in rows { grouped.entry(date_of(&row)).or_default().push(row); } grouped } fn collect_benchmark_code(benchmarks: &[BenchmarkSnapshot]) -> Result { let mut codes = benchmarks .iter() .map(|row| row.benchmark.clone()) .collect::>(); codes.sort_unstable(); codes.dedup(); if codes.len() == 1 { Ok(codes.remove(0)) } else { Err(DataSetError::MultipleBenchmarks) } } fn round2(value: f64) -> f64 { (value * 100.0).round() / 100.0 }