Add string factor support

This commit is contained in:
boris
2026-04-23 23:05:43 -07:00
parent 0b0b9333fa
commit 47988cd7e7
6 changed files with 587 additions and 31 deletions

View File

@@ -411,6 +411,15 @@ pub struct FactorValue {
pub value: f64,
}
/// One text-valued (string) factor observation: a `value` recorded for a
/// `symbol` under a named `field` on a given `date`.
#[derive(Debug, Clone, Serialize)]
pub struct FactorTextValue {
/// Date the observation is keyed by; serialized through the `date_format` module.
#[serde(with = "date_format")]
pub date: NaiveDate,
/// Symbol the observation belongs to.
pub symbol: String,
/// Name of the factor field. `DataSet` construction passes this through
/// `normalize_field` and drops rows whose normalized name is empty.
pub field: String,
/// The text value itself, stored as-is.
pub value: String,
}
#[derive(Debug, Clone, Serialize)]
pub struct SecuritiesMarginRecord {
#[serde(with = "date_format")]
@@ -694,6 +703,8 @@ pub struct DataSet {
market_index: HashMap<(NaiveDate, String), DailyMarketSnapshot>,
factor_by_date: BTreeMap<NaiveDate, Vec<DailyFactorSnapshot>>,
factor_index: HashMap<(NaiveDate, String), DailyFactorSnapshot>,
factor_text_by_date: BTreeMap<NaiveDate, Vec<FactorTextValue>>,
factor_text_index: HashMap<(NaiveDate, String, String), FactorTextValue>,
candidate_by_date: BTreeMap<NaiveDate, Vec<CandidateEligibility>>,
candidate_index: HashMap<(NaiveDate, String), CandidateEligibility>,
corporate_actions_by_date: BTreeMap<NaiveDate, Vec<CorporateAction>>,
@@ -712,6 +723,7 @@ impl DataSet {
let instruments = read_instruments(&path.join("instruments.csv"))?;
let market = read_market(&path.join("market.csv"))?;
let factors = read_factors(&path.join("factors.csv"))?;
let factor_texts = read_factor_texts(&path.join("factors.csv"))?;
let candidates = read_candidates(&path.join("candidate_flags.csv"))?;
let benchmarks = read_benchmarks(&path.join("benchmark.csv"))?;
let corporate_actions_path = path.join("corporate_actions.csv");
@@ -738,7 +750,7 @@ impl DataSet {
} else {
Vec::new()
};
Self::from_components_with_actions_quotes_futures_and_depth(
Self::from_components_with_actions_quotes_futures_depth_and_factor_texts(
instruments,
market,
factors,
@@ -748,6 +760,7 @@ impl DataSet {
execution_quotes,
futures_params,
order_book_depth,
factor_texts,
)
}
@@ -756,6 +769,7 @@ impl DataSet {
let benchmarks = read_partitioned_dir(&path.join("benchmark"), read_benchmarks)?;
let market = read_partitioned_dir(&path.join("market"), read_market)?;
let factors = read_partitioned_dir(&path.join("factors"), read_factors)?;
let factor_texts = read_partitioned_dir(&path.join("factors"), read_factor_texts)?;
let candidates = read_partitioned_dir(&path.join("candidates"), read_candidates)?;
let corporate_actions_dir = path.join("corporate_actions");
let corporate_actions = if corporate_actions_dir.exists() {
@@ -781,7 +795,7 @@ impl DataSet {
} else {
Vec::new()
};
Self::from_components_with_actions_quotes_futures_and_depth(
Self::from_components_with_actions_quotes_futures_depth_and_factor_texts(
instruments,
market,
factors,
@@ -791,6 +805,7 @@ impl DataSet {
execution_quotes,
futures_params,
order_book_depth,
factor_texts,
)
}
@@ -885,6 +900,54 @@ impl DataSet {
execution_quotes: Vec<IntradayExecutionQuote>,
futures_params: Vec<FuturesTradingParameter>,
order_book_depth: Vec<IntradayOrderBookDepthLevel>,
) -> Result<Self, DataSetError> {
Self::from_components_with_actions_quotes_futures_depth_and_factor_texts(
instruments,
market,
factors,
candidates,
benchmarks,
corporate_actions,
execution_quotes,
futures_params,
order_book_depth,
Vec::new(),
)
}
/// Builds a `DataSet` from the core components plus text factors.
///
/// Convenience wrapper around
/// `from_components_with_actions_quotes_futures_depth_and_factor_texts` that
/// supplies empty collections for corporate actions, execution quotes,
/// futures trading parameters and order-book depth.
///
/// # Errors
///
/// Returns whatever `DataSetError` the full constructor reports.
pub fn from_components_with_factor_texts(
    instruments: Vec<Instrument>,
    market: Vec<DailyMarketSnapshot>,
    factors: Vec<DailyFactorSnapshot>,
    candidates: Vec<CandidateEligibility>,
    benchmarks: Vec<BenchmarkSnapshot>,
    factor_texts: Vec<FactorTextValue>,
) -> Result<Self, DataSetError> {
    // Optional inputs this entry point does not accept are passed as empty.
    let corporate_actions = Vec::new();
    let execution_quotes = Vec::new();
    let futures_params = Vec::new();
    let order_book_depth = Vec::new();

    Self::from_components_with_actions_quotes_futures_depth_and_factor_texts(
        instruments,
        market,
        factors,
        candidates,
        benchmarks,
        corporate_actions,
        execution_quotes,
        futures_params,
        order_book_depth,
        factor_texts,
    )
}
pub fn from_components_with_actions_quotes_futures_depth_and_factor_texts(
instruments: Vec<Instrument>,
market: Vec<DailyMarketSnapshot>,
factors: Vec<DailyFactorSnapshot>,
candidates: Vec<CandidateEligibility>,
benchmarks: Vec<BenchmarkSnapshot>,
corporate_actions: Vec<CorporateAction>,
execution_quotes: Vec<IntradayExecutionQuote>,
futures_params: Vec<FuturesTradingParameter>,
order_book_depth: Vec<IntradayOrderBookDepthLevel>,
factor_texts: Vec<FactorTextValue>,
) -> Result<Self, DataSetError> {
let benchmark_code = collect_benchmark_code(&benchmarks)?;
let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect());
@@ -905,6 +968,22 @@ impl DataSet {
.into_iter()
.map(|item| ((item.date, item.symbol.clone()), item))
.collect::<HashMap<_, _>>();
let factor_texts = factor_texts
.into_iter()
.filter_map(|mut item| {
item.field = normalize_field(&item.field);
if item.field.is_empty() {
None
} else {
Some(item)
}
})
.collect::<Vec<_>>();
let factor_text_by_date = group_by_date(factor_texts.clone(), |item| item.date);
let factor_text_index = factor_texts
.into_iter()
.map(|item| ((item.date, item.symbol.clone(), item.field.clone()), item))
.collect::<HashMap<_, _>>();
let candidate_by_date = group_by_date(candidates.clone(), |item| item.date);
let candidate_index = candidates
@@ -933,6 +1012,8 @@ impl DataSet {
market_index,
factor_by_date,
factor_index,
factor_text_by_date,
factor_text_index,
candidate_by_date,
candidate_index,
corporate_actions_by_date,
@@ -1271,6 +1352,30 @@ impl DataSet {
rows
}
/// Returns the text-factor rows for `symbol` and `field` whose date lies in
/// the inclusive range `start..=end`, ordered by date.
///
/// `field` is passed through `normalize_field` before comparison, and each
/// stored row's field is normalized the same way. An inverted range
/// (`start > end`) yields an empty vector.
pub fn get_factor_text(
    &self,
    symbol: &str,
    start: NaiveDate,
    end: NaiveDate,
    field: &str,
) -> Vec<FactorTextValue> {
    let mut matches = Vec::new();
    // `BTreeMap::range` panics on an inverted range, so only query when the
    // bounds are ordered.
    if start <= end {
        let wanted = normalize_field(field);
        for (_, entries) in self.factor_text_by_date.range(start..=end) {
            for entry in entries {
                let same_symbol = entry.symbol == symbol;
                if same_symbol && normalize_field(&entry.field) == wanted {
                    matches.push(entry.clone());
                }
            }
        }
        matches.sort_by_key(|entry| entry.date);
    }
    matches
}
pub fn get_yield_curve(
&self,
start: NaiveDate,
@@ -1555,6 +1660,33 @@ impl DataSet {
None
}
/// Finds the most recent industry-name text factor for `symbol` dated on or
/// before `date`.
///
/// Dates are scanned newest-first. Within a date, the first stored row for
/// `symbol` whose normalized field equals any alias produced by
/// `industry_name_factor_aliases(source, level)` is returned; the position
/// of the alias in that list does not influence which row wins. The result
/// carries the date key under which the row was found. Returns `None` when
/// no row matches.
pub fn get_industry_name(
    &self,
    symbol: &str,
    date: NaiveDate,
    source: &str,
    level: usize,
) -> Option<FactorTextValue> {
    let aliases = industry_name_factor_aliases(source, level);
    self.factor_text_by_date
        .range(..=date)
        .rev()
        .find_map(|(found_on, entries)| {
            entries
                .iter()
                .filter(|entry| entry.symbol == symbol)
                .find(|entry| aliases.contains(&normalize_field(&entry.field)))
                .map(|entry| FactorTextValue {
                    date: *found_on,
                    ..entry.clone()
                })
        })
}
pub fn get_dominant_future(&self, underlying_symbol: &str, date: NaiveDate) -> Option<String> {
let underlying = normalize_field(underlying_symbol);
let mut candidates = self
@@ -1656,6 +1788,13 @@ impl DataSet {
.unwrap_or_default()
}
/// Returns references to every text-factor row stored for exactly `date`,
/// or an empty vector when that date has no rows.
pub fn factor_text_snapshots_on(&self, date: NaiveDate) -> Vec<&FactorTextValue> {
    match self.factor_text_by_date.get(&date) {
        Some(entries) => entries.iter().collect(),
        None => Vec::new(),
    }
}
pub fn market_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyMarketSnapshot> {
self.market_by_date
.get(&date)
@@ -1796,6 +1935,12 @@ impl DataSet {
.and_then(|snapshot| factor_numeric_value(snapshot, field))
}
/// Looks up the text value stored for the exact `(date, symbol, field)`
/// combination, normalizing `field` with `normalize_field` first.
///
/// Returns a clone of the stored value, or `None` when there is no entry.
pub fn factor_text_value(&self, date: NaiveDate, symbol: &str, field: &str) -> Option<String> {
    let key = (date, symbol.to_string(), normalize_field(field));
    let entry = self.factor_text_index.get(&key)?;
    Some(entry.value.clone())
}
fn get_first_available_factor_series(
&self,
symbol: &str,
@@ -2034,6 +2179,7 @@ fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
let rows = read_rows(path)?;
let mut snapshots = Vec::new();
for row in rows {
let (extra_factors, _) = parse_extra_factor_maps(&row);
snapshots.push(DailyFactorSnapshot {
date: row.parse_date(0)?,
symbol: row.get(1)?.to_string(),
@@ -2042,17 +2188,76 @@ fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
pe_ttm: row.parse_f64(4)?,
turnover_ratio: row.parse_optional_f64(5),
effective_turnover_ratio: row.parse_optional_f64(6),
extra_factors: row
.fields
.get(7)
.filter(|value| !value.trim().is_empty())
.and_then(|value| serde_json::from_str::<BTreeMap<String, f64>>(value).ok())
.unwrap_or_default(),
extra_factors,
});
}
Ok(snapshots)
}
/// Reads the text-valued extra factors from the CSV file at `path`.
///
/// For every row, column 0 is parsed as the date and column 1 is taken as
/// the symbol; each text entry returned by `parse_extra_factor_maps` becomes
/// one `FactorTextValue`. Numeric entries are ignored here.
///
/// # Errors
///
/// Propagates errors from `read_rows`, from parsing the date column, and
/// from fetching the symbol column.
fn read_factor_texts(path: &Path) -> Result<Vec<FactorTextValue>, DataSetError> {
    let mut collected = Vec::new();
    for record in read_rows(path)? {
        let date = record.parse_date(0)?;
        let symbol = record.get(1)?.to_string();
        let (_, text_entries) = parse_extra_factor_maps(&record);
        collected.extend(
            text_entries
                .into_iter()
                .map(|(field, value)| FactorTextValue {
                    date,
                    symbol: symbol.clone(),
                    field,
                    value,
                }),
        );
    }
    Ok(collected)
}
/// Collects the extra factors of one CSV row into a numeric map and a text
/// map.
///
/// Columns 7 and 8 (when present) are each handed to
/// `merge_extra_factor_json`, column 7 first, so entries from column 8
/// overwrite same-named entries from column 7.
fn parse_extra_factor_maps(row: &CsvRow) -> (BTreeMap<String, f64>, BTreeMap<String, String>) {
    let mut numeric_factors = BTreeMap::new();
    let mut text_factors = BTreeMap::new();
    for column in [7, 8] {
        if let Some(cell) = row.fields.get(column) {
            merge_extra_factor_json(cell, &mut numeric_factors, &mut text_factors);
        }
    }
    (numeric_factors, text_factors)
}
/// Merges the entries of a JSON object held in `raw` into `numeric` and
/// `text`.
///
/// Blank input, unparseable JSON and JSON that is not an object are ignored
/// without error. Each key is passed through `normalize_field`; keys that
/// normalize to an empty string are skipped. Values are routed by JSON type:
///
/// * numbers go to `numeric` when they convert to a finite `f64`;
/// * strings go to `text`;
/// * booleans go to both maps (`1.0`/`0.0` and `"true"`/`"false"`);
/// * null, arrays and nested objects are dropped.
///
/// An inserted entry replaces any existing entry with the same key.
fn merge_extra_factor_json(
    raw: &str,
    numeric: &mut BTreeMap<String, f64>,
    text: &mut BTreeMap<String, String>,
) {
    let payload = raw.trim();
    if payload.is_empty() {
        return;
    }
    let entries = match serde_json::from_str::<serde_json::Value>(payload) {
        Ok(serde_json::Value::Object(entries)) => entries,
        _ => return,
    };
    for (raw_name, entry) in entries {
        let name = normalize_field(&raw_name);
        if name.is_empty() {
            continue;
        }
        match entry {
            serde_json::Value::Number(number) => match number.as_f64() {
                Some(parsed) if parsed.is_finite() => {
                    numeric.insert(name, parsed);
                }
                _ => {}
            },
            serde_json::Value::String(content) => {
                text.insert(name, content);
            }
            serde_json::Value::Bool(flag) => {
                let as_number = if flag { 1.0 } else { 0.0 };
                numeric.insert(name.clone(), as_number);
                text.insert(name, flag.to_string());
            }
            _ => {}
        }
    }
}
fn normalized_aliases(values: &[String]) -> Vec<String> {
let mut aliases = Vec::new();
for value in values {
@@ -2191,6 +2396,21 @@ fn industry_factor_aliases(source: &str, level: usize) -> Vec<String> {
])
}
/// Builds the candidate field names under which an industry name for the
/// given classification `source` and `level` may be stored.
///
/// `source` is passed through `normalize_field` before being interpolated,
/// and the finished list is passed through `normalized_aliases`.
fn industry_name_factor_aliases(source: &str, level: usize) -> Vec<String> {
    let source = normalize_field(source);
    // Listed from source-qualified spellings down to the bare `industry_name`.
    let candidates: [String; 9] = [
        format!("industry_{source}_l{level}_name"),
        format!("industry_{source}_{level}_name"),
        format!("industry_{source}_name_l{level}"),
        format!("{source}_industry_l{level}_name"),
        format!("{source}_industry_{level}_name"),
        format!("{source}_industry_name_l{level}"),
        format!("industry_l{level}_name"),
        format!("industry_{level}_name"),
        String::from("industry_name"),
    ];
    normalized_aliases(&candidates)
}
fn factor_numeric_value(snapshot: &DailyFactorSnapshot, field: &str) -> Option<f64> {
match field {
"market_cap" | "market_cap_bn" => Some(snapshot.market_cap_bn),
@@ -2653,16 +2873,39 @@ fn read_rows(path: &Path) -> Result<Vec<CsvRow>, DataSetError> {
rows.push(CsvRow {
path: path.display().to_string(),
line: line_no,
fields: line
.split(',')
.map(|field| field.trim().to_string())
.collect(),
fields: split_csv_line(line),
});
}
Ok(rows)
}
/// Splits one CSV record into its fields.
///
/// * Leading U+FEFF byte-order marks are stripped from the line.
/// * A double quote toggles quoted mode; commas inside quotes do not split.
/// * Inside quotes, a doubled quote (`""`) produces one literal `"`.
/// * Every field is whitespace-trimmed, including content that was quoted.
///
/// The result always contains at least one field; an empty line yields a
/// single empty string.
fn split_csv_line(line: &str) -> Vec<String> {
    let record = line.trim_start_matches('\u{feff}');
    let mut columns: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut quoted = false;
    let mut stream = record.chars().peekable();

    while let Some(symbol) = stream.next() {
        if symbol == '"' {
            if quoted && stream.peek() == Some(&'"') {
                // Escaped quote inside a quoted field: emit one and skip its twin.
                current.push('"');
                stream.next();
            } else {
                quoted = !quoted;
            }
        } else if symbol == ',' && !quoted {
            columns.push(current.trim().to_string());
            current.clear();
        } else {
            current.push(symbol);
        }
    }

    // The final field has no trailing comma to flush it.
    columns.push(current.trim().to_string());
    columns
}
fn group_by_date<T, F>(rows: Vec<T>, mut date_of: F) -> BTreeMap<NaiveDate, Vec<T>>
where
F: FnMut(&T) -> NaiveDate,
@@ -2854,3 +3097,52 @@ fn build_eligible_universe(
per_date
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a path in the system temp directory whose file name combines
    /// `name`, the current process id and the current time in nanoseconds.
    fn temp_csv_path(name: &str) -> std::path::PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        std::env::temp_dir().join(format!("{}_{}_{}.csv", name, std::process::id(), nanos))
    }

    /// A quoted CSV cell holding a JSON object with a number, a string that
    /// contains a comma, and a boolean must be split into numeric and text
    /// factors by `read_factors` and `read_factor_texts`.
    #[test]
    fn reads_mixed_numeric_and_text_extra_factors_from_quoted_csv_json() {
        let path = temp_csv_path("mixed_factor_maps");
        fs::write(
            &path,
            concat!(
                "date,symbol,market_cap_bn,free_float_cap_bn,pe_ttm,turnover_ratio,effective_turnover_ratio,extra_factors\n",
                "2025-01-02,000001.SZ,12,10,8,1,1,\"{\"\"custom_alpha\"\":7,\"\"industry_name\"\":\"\"electronics,hardware\"\",\"\"flag\"\":true}\"\n"
            ),
        )
        .unwrap();

        // Keep both results unresolved until the fixture is deleted, so a
        // reader error cannot panic first and leave the temp file behind.
        let factors = read_factors(&path);
        let text_factors = read_factor_texts(&path);
        fs::remove_file(&path).ok();
        let factors = factors.unwrap();
        let text_factors = text_factors.unwrap();

        assert_eq!(factors.len(), 1);
        assert_eq!(
            factors[0].extra_factors.get("custom_alpha").copied(),
            Some(7.0)
        );
        assert_eq!(factors[0].extra_factors.get("flag").copied(), Some(1.0));
        assert_eq!(text_factors.len(), 2);
        assert!(
            text_factors
                .iter()
                .any(|row| row.field == "industry_name" && row.value == "electronics,hardware")
        );
        assert!(
            text_factors
                .iter()
                .any(|row| row.field == "flag" && row.value == "true")
        );
    }
}