472 lines
14 KiB
Rust
472 lines
14 KiB
Rust
use std::collections::{BTreeMap, HashMap};
|
|
use std::fs;
|
|
use std::path::Path;
|
|
|
|
use chrono::NaiveDate;
|
|
use serde::{Deserialize, Serialize};
|
|
use thiserror::Error;
|
|
|
|
use crate::calendar::TradingCalendar;
|
|
use crate::instrument::Instrument;
|
|
|
|
mod date_format {
|
|
use chrono::NaiveDate;
|
|
use serde::{self, Deserialize, Deserializer, Serializer};
|
|
|
|
const FORMAT: &str = "%Y-%m-%d";
|
|
|
|
pub fn serialize<S>(date: &NaiveDate, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
serializer.serialize_str(&date.format(FORMAT).to_string())
|
|
}
|
|
|
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<NaiveDate, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let text = String::deserialize(deserializer)?;
|
|
NaiveDate::parse_from_str(&text, FORMAT).map_err(serde::de::Error::custom)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum DataSetError {
|
|
#[error("failed to read file {path}: {source}")]
|
|
Io {
|
|
path: String,
|
|
#[source]
|
|
source: std::io::Error,
|
|
},
|
|
#[error("invalid csv row in {path} at line {line}: {message}")]
|
|
InvalidRow {
|
|
path: String,
|
|
line: usize,
|
|
message: String,
|
|
},
|
|
#[error("benchmark file contains multiple benchmark codes")]
|
|
MultipleBenchmarks,
|
|
#[error("missing data for {kind} on {date} / {symbol}")]
|
|
MissingSnapshot {
|
|
kind: &'static str,
|
|
date: NaiveDate,
|
|
symbol: String,
|
|
},
|
|
#[error("benchmark snapshot missing for {date}")]
|
|
MissingBenchmark { date: NaiveDate },
|
|
}
|
|
|
|
/// Selects which price of a market snapshot a lookup should return.
///
/// `PartialEq`, `Eq` and `Hash` are derived so values can be compared,
/// asserted on, and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PriceField {
    /// Maps to the snapshot's `open` value.
    Open,
    /// Maps to the snapshot's `close` value.
    Close,
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct DailyMarketSnapshot {
|
|
#[serde(with = "date_format")]
|
|
pub date: NaiveDate,
|
|
pub symbol: String,
|
|
pub open: f64,
|
|
pub high: f64,
|
|
pub low: f64,
|
|
pub close: f64,
|
|
pub prev_close: f64,
|
|
pub volume: u64,
|
|
pub paused: bool,
|
|
pub upper_limit: f64,
|
|
pub lower_limit: f64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct DailyFactorSnapshot {
|
|
#[serde(with = "date_format")]
|
|
pub date: NaiveDate,
|
|
pub symbol: String,
|
|
pub market_cap_bn: f64,
|
|
pub free_float_cap_bn: f64,
|
|
pub pe_ttm: f64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct BenchmarkSnapshot {
|
|
#[serde(with = "date_format")]
|
|
pub date: NaiveDate,
|
|
pub benchmark: String,
|
|
pub open: f64,
|
|
pub close: f64,
|
|
pub prev_close: f64,
|
|
pub volume: u64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CandidateEligibility {
|
|
#[serde(with = "date_format")]
|
|
pub date: NaiveDate,
|
|
pub symbol: String,
|
|
pub is_st: bool,
|
|
pub is_new_listing: bool,
|
|
pub is_paused: bool,
|
|
pub allow_buy: bool,
|
|
pub allow_sell: bool,
|
|
}
|
|
|
|
impl CandidateEligibility {
|
|
pub fn eligible_for_selection(&self) -> bool {
|
|
!self.is_st && !self.is_new_listing && !self.is_paused && self.allow_buy && self.allow_sell
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct DataSet {
|
|
instruments: HashMap<String, Instrument>,
|
|
calendar: TradingCalendar,
|
|
market_by_date: BTreeMap<NaiveDate, Vec<DailyMarketSnapshot>>,
|
|
market_index: HashMap<(NaiveDate, String), DailyMarketSnapshot>,
|
|
factor_by_date: BTreeMap<NaiveDate, Vec<DailyFactorSnapshot>>,
|
|
factor_index: HashMap<(NaiveDate, String), DailyFactorSnapshot>,
|
|
candidate_by_date: BTreeMap<NaiveDate, Vec<CandidateEligibility>>,
|
|
candidate_index: HashMap<(NaiveDate, String), CandidateEligibility>,
|
|
benchmark_by_date: BTreeMap<NaiveDate, BenchmarkSnapshot>,
|
|
benchmark_code: String,
|
|
}
|
|
|
|
impl DataSet {
|
|
pub fn from_csv_dir(path: &Path) -> Result<Self, DataSetError> {
|
|
let instruments = read_instruments(&path.join("instruments.csv"))?;
|
|
let market = read_market(&path.join("market.csv"))?;
|
|
let factors = read_factors(&path.join("factors.csv"))?;
|
|
let candidates = read_candidates(&path.join("candidate_flags.csv"))?;
|
|
let benchmarks = read_benchmarks(&path.join("benchmark.csv"))?;
|
|
|
|
let benchmark_code = collect_benchmark_code(&benchmarks)?;
|
|
let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect());
|
|
|
|
let instruments = instruments
|
|
.into_iter()
|
|
.map(|instrument| (instrument.symbol.clone(), instrument))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let market_by_date = group_by_date(market.clone(), |item| item.date);
|
|
let market_index = market
|
|
.into_iter()
|
|
.map(|item| ((item.date, item.symbol.clone()), item))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let factor_by_date = group_by_date(factors.clone(), |item| item.date);
|
|
let factor_index = factors
|
|
.into_iter()
|
|
.map(|item| ((item.date, item.symbol.clone()), item))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let candidate_by_date = group_by_date(candidates.clone(), |item| item.date);
|
|
let candidate_index = candidates
|
|
.into_iter()
|
|
.map(|item| ((item.date, item.symbol.clone()), item))
|
|
.collect::<HashMap<_, _>>();
|
|
|
|
let benchmark_by_date = benchmarks
|
|
.into_iter()
|
|
.map(|item| (item.date, item))
|
|
.collect::<BTreeMap<_, _>>();
|
|
|
|
Ok(Self {
|
|
instruments,
|
|
calendar,
|
|
market_by_date,
|
|
market_index,
|
|
factor_by_date,
|
|
factor_index,
|
|
candidate_by_date,
|
|
candidate_index,
|
|
benchmark_by_date,
|
|
benchmark_code,
|
|
})
|
|
}
|
|
|
|
pub fn calendar(&self) -> &TradingCalendar {
|
|
&self.calendar
|
|
}
|
|
|
|
pub fn benchmark_code(&self) -> &str {
|
|
&self.benchmark_code
|
|
}
|
|
|
|
pub fn instruments(&self) -> &HashMap<String, Instrument> {
|
|
&self.instruments
|
|
}
|
|
|
|
pub fn market(&self, date: NaiveDate, symbol: &str) -> Option<&DailyMarketSnapshot> {
|
|
self.market_index.get(&(date, symbol.to_string()))
|
|
}
|
|
|
|
pub fn factor(&self, date: NaiveDate, symbol: &str) -> Option<&DailyFactorSnapshot> {
|
|
self.factor_index.get(&(date, symbol.to_string()))
|
|
}
|
|
|
|
pub fn candidate(&self, date: NaiveDate, symbol: &str) -> Option<&CandidateEligibility> {
|
|
self.candidate_index.get(&(date, symbol.to_string()))
|
|
}
|
|
|
|
pub fn benchmark(&self, date: NaiveDate) -> Option<&BenchmarkSnapshot> {
|
|
self.benchmark_by_date.get(&date)
|
|
}
|
|
|
|
pub fn benchmark_series(&self) -> Vec<BenchmarkSnapshot> {
|
|
self.benchmark_by_date.values().cloned().collect()
|
|
}
|
|
|
|
pub fn price(&self, date: NaiveDate, symbol: &str, field: PriceField) -> Option<f64> {
|
|
let snapshot = self.market(date, symbol)?;
|
|
Some(match field {
|
|
PriceField::Open => snapshot.open,
|
|
PriceField::Close => snapshot.close,
|
|
})
|
|
}
|
|
|
|
pub fn factor_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyFactorSnapshot> {
|
|
self.factor_by_date
|
|
.get(&date)
|
|
.map(|rows| rows.iter().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn market_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyMarketSnapshot> {
|
|
self.market_by_date
|
|
.get(&date)
|
|
.map(|rows| rows.iter().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn candidate_snapshots_on(&self, date: NaiveDate) -> Vec<&CandidateEligibility> {
|
|
self.candidate_by_date
|
|
.get(&date)
|
|
.map(|rows| rows.iter().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn benchmark_closes_up_to(&self, date: NaiveDate, lookback: usize) -> Vec<f64> {
|
|
self.calendar
|
|
.trailing_days(date, lookback)
|
|
.into_iter()
|
|
.filter_map(|day| self.benchmark(day).map(|row| row.close))
|
|
.collect()
|
|
}
|
|
|
|
pub fn require_market(
|
|
&self,
|
|
date: NaiveDate,
|
|
symbol: &str,
|
|
) -> Result<&DailyMarketSnapshot, DataSetError> {
|
|
self.market(date, symbol).ok_or_else(|| DataSetError::MissingSnapshot {
|
|
kind: "market",
|
|
date,
|
|
symbol: symbol.to_string(),
|
|
})
|
|
}
|
|
|
|
pub fn require_candidate(
|
|
&self,
|
|
date: NaiveDate,
|
|
symbol: &str,
|
|
) -> Result<&CandidateEligibility, DataSetError> {
|
|
self.candidate(date, symbol)
|
|
.ok_or_else(|| DataSetError::MissingSnapshot {
|
|
kind: "candidate",
|
|
date,
|
|
symbol: symbol.to_string(),
|
|
})
|
|
}
|
|
}
|
|
|
|
fn read_instruments(path: &Path) -> Result<Vec<Instrument>, DataSetError> {
|
|
let rows = read_rows(path)?;
|
|
let mut instruments = Vec::new();
|
|
for row in rows {
|
|
instruments.push(Instrument {
|
|
symbol: row.get(0)?.to_string(),
|
|
name: row.get(1)?.to_string(),
|
|
board: row.get(2)?.to_string(),
|
|
});
|
|
}
|
|
Ok(instruments)
|
|
}
|
|
|
|
fn read_market(path: &Path) -> Result<Vec<DailyMarketSnapshot>, DataSetError> {
|
|
let rows = read_rows(path)?;
|
|
let mut snapshots = Vec::new();
|
|
for row in rows {
|
|
let prev_close = row.parse_f64(6)?;
|
|
snapshots.push(DailyMarketSnapshot {
|
|
date: row.parse_date(0)?,
|
|
symbol: row.get(1)?.to_string(),
|
|
open: row.parse_f64(2)?,
|
|
high: row.parse_f64(3)?,
|
|
low: row.parse_f64(4)?,
|
|
close: row.parse_f64(5)?,
|
|
prev_close,
|
|
volume: row.parse_u64(7)?,
|
|
paused: row.parse_bool(8)?,
|
|
upper_limit: round2(prev_close * 1.10),
|
|
lower_limit: round2(prev_close * 0.90),
|
|
});
|
|
}
|
|
Ok(snapshots)
|
|
}
|
|
|
|
fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
|
|
let rows = read_rows(path)?;
|
|
let mut snapshots = Vec::new();
|
|
for row in rows {
|
|
snapshots.push(DailyFactorSnapshot {
|
|
date: row.parse_date(0)?,
|
|
symbol: row.get(1)?.to_string(),
|
|
market_cap_bn: row.parse_f64(2)?,
|
|
free_float_cap_bn: row.parse_f64(3)?,
|
|
pe_ttm: row.parse_f64(4)?,
|
|
});
|
|
}
|
|
Ok(snapshots)
|
|
}
|
|
|
|
fn read_candidates(path: &Path) -> Result<Vec<CandidateEligibility>, DataSetError> {
|
|
let rows = read_rows(path)?;
|
|
let mut snapshots = Vec::new();
|
|
for row in rows {
|
|
snapshots.push(CandidateEligibility {
|
|
date: row.parse_date(0)?,
|
|
symbol: row.get(1)?.to_string(),
|
|
is_st: row.parse_bool(2)?,
|
|
is_new_listing: row.parse_bool(3)?,
|
|
is_paused: row.parse_bool(4)?,
|
|
allow_buy: row.parse_bool(5)?,
|
|
allow_sell: row.parse_bool(6)?,
|
|
});
|
|
}
|
|
Ok(snapshots)
|
|
}
|
|
|
|
fn read_benchmarks(path: &Path) -> Result<Vec<BenchmarkSnapshot>, DataSetError> {
|
|
let rows = read_rows(path)?;
|
|
let mut snapshots = Vec::new();
|
|
for row in rows {
|
|
snapshots.push(BenchmarkSnapshot {
|
|
date: row.parse_date(0)?,
|
|
benchmark: row.get(1)?.to_string(),
|
|
open: row.parse_f64(2)?,
|
|
close: row.parse_f64(3)?,
|
|
prev_close: row.parse_f64(4)?,
|
|
volume: row.parse_u64(5)?,
|
|
});
|
|
}
|
|
Ok(snapshots)
|
|
}
|
|
|
|
/// One data line of a CSV file, already split into fields.
struct CsvRow {
    /// Display form of the source file path, kept for error messages.
    path: String,
    /// 1-based line number of this row within the file.
    line: usize,
    /// The row's cells, in column order.
    fields: Vec<String>,
}
|
|
|
|
impl CsvRow {
|
|
fn get(&self, index: usize) -> Result<&str, DataSetError> {
|
|
self.fields.get(index).map(String::as_str).ok_or_else(|| DataSetError::InvalidRow {
|
|
path: self.path.clone(),
|
|
line: self.line,
|
|
message: format!("missing column {index}"),
|
|
})
|
|
}
|
|
|
|
fn parse_date(&self, index: usize) -> Result<NaiveDate, DataSetError> {
|
|
NaiveDate::parse_from_str(self.get(index)?, "%Y-%m-%d").map_err(|err| DataSetError::InvalidRow {
|
|
path: self.path.clone(),
|
|
line: self.line,
|
|
message: format!("invalid date: {err}"),
|
|
})
|
|
}
|
|
|
|
fn parse_f64(&self, index: usize) -> Result<f64, DataSetError> {
|
|
self.get(index)?
|
|
.parse::<f64>()
|
|
.map_err(|err| DataSetError::InvalidRow {
|
|
path: self.path.clone(),
|
|
line: self.line,
|
|
message: format!("invalid f64: {err}"),
|
|
})
|
|
}
|
|
|
|
fn parse_u64(&self, index: usize) -> Result<u64, DataSetError> {
|
|
self.get(index)?
|
|
.parse::<u64>()
|
|
.map_err(|err| DataSetError::InvalidRow {
|
|
path: self.path.clone(),
|
|
line: self.line,
|
|
message: format!("invalid u64: {err}"),
|
|
})
|
|
}
|
|
|
|
fn parse_bool(&self, index: usize) -> Result<bool, DataSetError> {
|
|
self.get(index)?
|
|
.parse::<bool>()
|
|
.map_err(|err| DataSetError::InvalidRow {
|
|
path: self.path.clone(),
|
|
line: self.line,
|
|
message: format!("invalid bool: {err}"),
|
|
})
|
|
}
|
|
}
|
|
|
|
fn read_rows(path: &Path) -> Result<Vec<CsvRow>, DataSetError> {
|
|
let content = fs::read_to_string(path).map_err(|source| DataSetError::Io {
|
|
path: path.display().to_string(),
|
|
source,
|
|
})?;
|
|
|
|
let mut rows = Vec::new();
|
|
for (line_idx, line) in content.lines().enumerate() {
|
|
let line_no = line_idx + 1;
|
|
if line_no == 1 || line.trim().is_empty() {
|
|
continue;
|
|
}
|
|
|
|
rows.push(CsvRow {
|
|
path: path.display().to_string(),
|
|
line: line_no,
|
|
fields: line.split(',').map(|field| field.trim().to_string()).collect(),
|
|
});
|
|
}
|
|
|
|
Ok(rows)
|
|
}
|
|
|
|
fn group_by_date<T, F>(rows: Vec<T>, mut date_of: F) -> BTreeMap<NaiveDate, Vec<T>>
|
|
where
|
|
F: FnMut(&T) -> NaiveDate,
|
|
{
|
|
let mut grouped = BTreeMap::<NaiveDate, Vec<T>>::new();
|
|
for row in rows {
|
|
grouped.entry(date_of(&row)).or_default().push(row);
|
|
}
|
|
grouped
|
|
}
|
|
|
|
fn collect_benchmark_code(benchmarks: &[BenchmarkSnapshot]) -> Result<String, DataSetError> {
|
|
let mut codes = benchmarks
|
|
.iter()
|
|
.map(|row| row.benchmark.clone())
|
|
.collect::<Vec<_>>();
|
|
codes.sort_unstable();
|
|
codes.dedup();
|
|
|
|
if codes.len() == 1 {
|
|
Ok(codes.remove(0))
|
|
} else {
|
|
Err(DataSetError::MultipleBenchmarks)
|
|
}
|
|
}
|
|
|
|
/// Rounds `value` to two decimal places by scaling by 100, applying
/// `f64::round` (ties go away from zero) and scaling back.
fn round2(value: f64) -> f64 {
    const SCALE: f64 = 100.0;
    let scaled = value * SCALE;
    scaled.round() / SCALE
}
|