abacus/abacus-core/src/dataframe.rs

694 lines
22 KiB
Rust

use super::ScriptResult;
use polars::prelude::{IntoLazy, NamedFrom, SerReader};
use std::{
any::TypeId,
ops::{Add, Deref},
str::FromStr,
};
use rhai::{Dynamic, EvalAltResult, EvalContext, Expression, Position};
/// Registers the dataframe scripting API on `engine`.
///
/// This covers the `DataFrame`, `Series` and `DataFrameExpression` types, their methods,
/// indexers and arithmetic operator overloads, the `gt`/`gte` custom operators, and the
/// `from <ident> <expr> : <expr>` custom syntax handled by `implementation_df_select`.
///
/// # Panics
/// Panics if registering the `from` custom syntax fails, because that result is unwrapped.
pub fn setup_engine(engine: &mut rhai::Engine) {
// polars DataFrame: constructor, selection, CSV loading and column access by name or index.
engine.register_type::<DataFrame>();
engine.register_fn("dataframe", script_functions::dataframe);
engine.register_fn("select", DataFrame::s_select);
engine.register_fn("load_csv", DataFrame::load_csv);
engine.register_fn("column", DataFrame::s_column);
engine.register_indexer_get(DataFrame::s_column);
engine.register_indexer_set(DataFrame::append_column);
// polars Series: constructors, element access, helpers and arithmetic operator overloads.
engine.register_type::<Series>();
engine.register_indexer_get(Series::s_get);
engine.register_fn("series", script_functions::series);
engine.register_fn("to_series", script_functions::to_series);
engine.register_fn("to_series", script_functions::to_series_unnamed);
engine.register_fn("series", script_functions::series_unnamed);
engine.register_fn("head", Series::s_head);
engine.register_fn("sort", Series::s_sort);
engine.register_fn("sum", Series::s_sum);
engine.register_fn("add", Series::s_op_add);
engine.register_fn("+", Series::add);
engine.register_fn("+", script_functions::add_series_i64);
engine.register_fn("+", script_functions::add_series_f64);
engine.register_fn("-", script_functions::subtract_series_series);
engine.register_fn("-", script_functions::subtract_series_i64);
engine.register_fn("-", script_functions::subtract_series_f64);
engine.register_fn("**", script_functions::power_series_f64);
engine.register_fn("**", script_functions::power_series_i64);
engine.register_fn("*", script_functions::multiply_series_series);
engine.register_fn("*", script_functions::multiply_series_i64);
engine.register_fn("*", script_functions::multiply_series_f64);
engine.register_fn("/", script_functions::div_series_series);
engine.register_fn("/", script_functions::div_series_i64);
engine.register_fn("/", script_functions::div_series_f64);
// polars lazy expressions: methods on an existing expression first, then the free functions
// that build an expression from a column name.
engine.register_type::<DataFrameExpression>();
engine.register_fn("sum", DataFrameExpression::sum);
engine.register_fn("mean", DataFrameExpression::mean);
engine.register_fn("min", DataFrameExpression::min);
engine.register_fn("max", DataFrameExpression::max);
engine.register_fn("first", DataFrameExpression::first);
engine.register_fn("count", DataFrameExpression::count);
engine.register_fn("eq", DataFrameExpression::eq);
engine.register_fn("lt", DataFrameExpression::lt);
engine.register_fn("lte", DataFrameExpression::lte);
engine.register_fn("gt", DataFrameExpression::gt);
engine.register_fn("gte", DataFrameExpression::gte);
engine.register_fn("filter", DataFrameExpression::filter);
engine.register_fn("column", script_functions::column);
engine.register_fn("count", script_functions::count);
engine.register_fn("sum", script_functions::sum);
engine.register_fn("mean", script_functions::mean);
engine.register_fn("min", script_functions::min);
engine.register_fn("max", script_functions::max);
engine.register_fn("first", script_functions::first);
// Custom infix operators. The registration results are discarded (`let _ =`), so a failed
// registration goes unnoticed here.
// NOTE(review): no function named `<<` is registered in this function — confirm whether that
// operator is still needed.
let _ = engine.register_custom_operator("gt", 200);
let _ = engine.register_custom_operator("gte", 200);
let _ = engine.register_custom_operator("<<", 200);
engine.register_fn("gt", script_functions::gt_op);
engine.register_fn("gte", script_functions::gte_op);
// `from <dataframe ident> <select array> : <filter array>` query syntax.
engine
.register_custom_syntax(
["from", "$ident$", "$expr$", ":", "$expr$"], // the custom syntax
true, // variables declared within this custom syntax
implementation_df_select,
)
.unwrap();
}
/// Script-facing newtype around a polars `DataFrame`.
///
/// It is registered as a Rhai type in `setup_engine`, which is why it derives `Clone`.
#[derive(Clone, Debug, PartialEq)]
pub struct DataFrame(polars::frame::DataFrame);
impl Deref for DataFrame {
type Target = polars::frame::DataFrame;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DataFrame {
    /// Returns a clone of the column called `name`, wrapped as a script [`Series`].
    ///
    /// # Errors
    /// Returns the polars error message when the column lookup fails.
    pub fn s_column(&mut self, name: &str) -> ScriptResult<Series> {
        let column = self.0.column(name).map_err(|e| e.to_string())?;
        Ok(Series(column.clone()))
    }

    /// Reads the CSV file at `path` into a new frame.
    ///
    /// The reader is configured to expect a header row, to infer the schema from a single row
    /// and to ignore parser errors.
    ///
    /// # Errors
    /// Returns the polars error message when the file cannot be opened or read.
    pub fn load_csv(path: &str) -> ScriptResult<DataFrame> {
        let reader = polars::io::csv::CsvReader::from_path(path)
            .map_err(|e| e.to_string())?
            .infer_schema(Some(1))
            .with_ignore_parser_errors(true)
            .has_header(true);
        let frame = reader.finish().map_err(|e| e.to_string())?;
        Ok(DataFrame(frame))
    }

    /// Evaluates `expressions` lazily against a clone of this frame and returns the result.
    ///
    /// Array entries that are not [`DataFrameExpression`]s are skipped.
    ///
    /// # Errors
    /// Returns the polars error message when collecting the lazy query fails.
    pub fn s_select(&mut self, expressions: rhai::Array) -> ScriptResult<DataFrame> {
        let mut selected: Vec<polars::lazy::dsl::Expr> = Vec::new();
        for item in expressions {
            if let Some(expression) = item.try_cast::<DataFrameExpression>() {
                selected.push(expression.into());
            }
        }
        let frame = self
            .0
            .clone()
            .lazy()
            .select(selected)
            .collect()
            .map_err(|e| e.to_string())?;
        Ok(DataFrame(frame))
    }

    /// Index setter (`frame[idx] = series`): renames `b` to `idx` and stores it via
    /// `with_column`.
    ///
    /// # Errors
    /// Returns the polars error message when `with_column` rejects the series.
    pub fn append_column(&mut self, idx: &str, mut b: Series) -> ScriptResult<()> {
        b.0.rename(idx);
        let appended = self.0.with_column(b.0);
        appended.map_err(|e| e.to_string())?;
        Ok(())
    }
}
/// Script-facing newtype around a polars lazy `Expr`.
///
/// It is registered as a Rhai type in `setup_engine`, which is why it derives `Clone`.
#[derive(Clone, Debug)]
pub struct DataFrameExpression(polars::lazy::dsl::Expr);
impl From<DataFrameExpression> for polars::lazy::dsl::Expr {
fn from(source: DataFrameExpression) -> Self {
source.0
}
}
impl DataFrameExpression {
fn first(&mut self) -> Self {
DataFrameExpression(self.0.clone().first())
}
fn sum(&mut self) -> Self {
DataFrameExpression(self.0.clone().sum())
}
fn mean(&mut self) -> Self {
DataFrameExpression(self.0.clone().mean())
}
fn max(&mut self) -> Self {
DataFrameExpression(self.0.clone().max())
}
fn min(&mut self) -> Self {
DataFrameExpression(self.0.clone().min())
}
fn count(&mut self) -> Self {
DataFrameExpression(self.0.clone().count())
}
fn eq(&mut self, e: rhai::Dynamic) -> Self {
DataFrameExpression(self.0.clone().eq(DataFrameExpression::from(e)))
}
fn gt(&mut self, e: rhai::Dynamic) -> Self {
DataFrameExpression(self.0.clone().gt(DataFrameExpression::from(e)))
}
fn lt(&mut self, e: rhai::Dynamic) -> Self {
DataFrameExpression(self.0.clone().lt(DataFrameExpression::from(e)))
}
fn gte(&mut self, e: rhai::Dynamic) -> Self {
DataFrameExpression(self.0.clone().gt_eq(DataFrameExpression::from(e)))
}
fn lte(&mut self, e: rhai::Dynamic) -> Self {
DataFrameExpression(self.0.clone().lt_eq(DataFrameExpression::from(e)))
}
pub fn filter(&mut self, expr: DataFrameExpression) -> DataFrameExpression {
DataFrameExpression(self.0.clone().filter(expr.0))
}
}
/// Converts a script value into an expression usable as a comparison operand.
///
/// * A value that already holds a [`DataFrameExpression`] is passed through unchanged, so
///   expressions can be compared with each other (e.g. `column("a").eq(column("b"))`).
/// * Integers, floats and booleans go through polars' `Into<Expr>` conversions.
/// * Strings go through polars' `From<&str> for Expr`.
///   NOTE(review): confirm whether that conversion yields a column reference or a string
///   literal; callers comparing against text depend on it.
/// * Any other value becomes `Expr::default()`.
impl From<rhai::Dynamic> for DataFrameExpression {
    fn from(source: rhai::Dynamic) -> Self {
        match source {
            // Previously an expression operand fell through to the default arm and was lost.
            x if x.is::<DataFrameExpression>() => x.cast::<DataFrameExpression>(),
            // The `unwrap`s below cannot fail: each guard has just checked the held type.
            x if x.is::<i64>() => DataFrameExpression(x.as_int().unwrap().into()),
            x if x.is::<f64>() => DataFrameExpression(x.as_float().unwrap().into()),
            x if x.is::<String>() => DataFrameExpression(x.into_string().unwrap()[..].into()),
            x if x.is::<bool>() => DataFrameExpression(x.as_bool().unwrap().into()),
            _ => DataFrameExpression(polars::lazy::dsl::Expr::default()),
        }
    }
}
/// Script-facing newtype around a polars `Series`.
///
/// It is registered as a Rhai type in `setup_engine`, which is why it derives `Clone`.
#[derive(Clone, Debug, PartialEq)]
pub struct Series(polars::series::Series);
impl Deref for Series {
type Target = polars::series::Series;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Add for Series {
type Output = Series;
fn add(self, rhs: Self) -> Self::Output {
Series(self.0 + rhs.0)
}
}
impl Series {
pub fn s_head(&mut self, n: i64) -> Series {
Series(self.0.head(Some(n as usize)))
}
pub fn s_sort(&mut self, reverse: bool) -> Series {
Series(self.0.sort(reverse))
}
pub fn s_sum(&mut self) -> i64 {
self.0.sum().unwrap_or_default()
}
pub fn s_get(&mut self, n: i64) -> ScriptResult<rhai::Dynamic> {
let value = self.get(n as usize);
match value {
polars::datatypes::AnyValue::Utf8(v) => {
Ok(rhai::Dynamic::from_str(v).unwrap_or_default())
}
polars::prelude::AnyValue::Null => Ok(rhai::Dynamic::UNIT),
polars::prelude::AnyValue::Boolean(v) => Ok(rhai::Dynamic::from_bool(v)),
polars::prelude::AnyValue::UInt8(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::UInt16(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::UInt32(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::UInt64(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Int8(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Int16(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Int32(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Int64(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Float32(v) => Ok(rhai::Dynamic::from_float(v as f64)),
polars::prelude::AnyValue::Float64(v) => Ok(rhai::Dynamic::from_float(v as f64)),
polars::prelude::AnyValue::Date(v) => Ok(rhai::Dynamic::from_int(v as i64)),
polars::prelude::AnyValue::Datetime(_, _, _) => Ok(rhai::Dynamic::from_int(0)),
polars::prelude::AnyValue::Duration(_, _) => Ok(rhai::Dynamic::from_int(0)),
polars::prelude::AnyValue::Time(_) => Ok(rhai::Dynamic::from_int(0)),
polars::prelude::AnyValue::List(s) => Ok(rhai::Dynamic::from(Series(s))),
polars::prelude::AnyValue::Utf8Owned(v) => {
Ok(rhai::Dynamic::from_str(&v).unwrap_or_default())
}
}
}
pub fn s_op_add(self, series: Series) -> Series {
Series(self.0 + series.0)
}
}
/// Free functions exposed to scripts: series arithmetic, expression builders and the
/// `dataframe` / `series` constructors.
mod script_functions {
    use super::*;

    /// `series + integer`, delegated to polars.
    pub fn add_series_i64(a: Series, b: i64) -> Series {
        Series(a.0 + b)
    }

    /// `series + float`, delegated to polars.
    pub fn add_series_f64(a: Series, b: f64) -> Series {
        Series(a.0 + b)
    }

    /// `series - integer`, delegated to polars.
    pub fn subtract_series_i64(a: Series, b: i64) -> Series {
        Series(a.0 - b)
    }

    /// `series - float`, delegated to polars.
    pub fn subtract_series_f64(a: Series, b: f64) -> Series {
        Series(a.0 - b)
    }

    /// `series - series`, delegated to polars.
    pub fn subtract_series_series(a: Series, b: Series) -> Series {
        Series(a.0 - b.0)
    }

    /// `series ** integer`: evaluates `col(name).pow(b)` over a one-column frame built from `a`
    /// and returns the resulting column.
    ///
    /// # Errors
    /// Returns the polars error message when the lazy query or the column lookup fails.
    pub fn power_series_i64(a: Series, b: i64) -> ScriptResult<Series> {
        let name = a.name();
        let a_series = a.0.clone();
        let df = a_series
            .into_frame()
            .lazy()
            .select([polars::prelude::col(name).pow(b)])
            .collect()
            .map_err(|e| e.to_string())?;
        let s = df.column(name).map_err(|e| e.to_string())?;
        Ok(Series(s.clone()))
    }

    /// `series ** float`: evaluates `col(name).pow(b)` over a one-column frame built from `a`
    /// and returns the resulting column.
    ///
    /// # Errors
    /// Returns the polars error message when the lazy query or the column lookup fails.
    pub fn power_series_f64(a: Series, b: f64) -> ScriptResult<Series> {
        let name = a.name();
        let a_series = a.0.clone();
        let df = a_series
            .into_frame()
            .lazy()
            .select([polars::prelude::col(name).pow(b)])
            .collect()
            .map_err(|e| e.to_string())?;
        let s = df.column(name).map_err(|e| e.to_string())?;
        Ok(Series(s.clone()))
    }

    /// `series * integer`, delegated to polars.
    pub fn multiply_series_i64(a: Series, b: i64) -> Series {
        Series(a.0 * b)
    }

    /// `series * float`, delegated to polars.
    pub fn multiply_series_f64(a: Series, b: f64) -> Series {
        Series(a.0 * b)
    }

    /// `series * series`, delegated to polars.
    pub fn multiply_series_series(a: Series, b: Series) -> Series {
        Series(a.0 * b.0)
    }

    /// `series / integer`, delegated to polars.
    pub fn div_series_i64(a: Series, b: i64) -> Series {
        Series(a.0 / b)
    }

    /// `series / float`, delegated to polars.
    pub fn div_series_f64(a: Series, b: f64) -> Series {
        Series(a.0 / b)
    }

    /// `series / series`, delegated to polars.
    pub fn div_series_series(a: Series, b: Series) -> Series {
        Series(a.0 / b.0)
    }

    /// Builds the polars expression `col(name)`.
    pub fn column(name: &str) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::col(name))
    }

    /// Builds the polars expression `sum(name)`.
    pub fn sum(name: &str) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::sum(name))
    }

    /// Builds the polars expression `min(name)`.
    pub fn min(name: &str) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::min(name))
    }

    /// Builds the polars expression `max(name)`.
    pub fn max(name: &str) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::max(name))
    }

    /// Builds the polars expression `mean(name)`.
    pub fn mean(name: &str) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::mean(name))
    }

    /// Builds the polars expression `first()`.
    pub fn first() -> DataFrameExpression {
        DataFrameExpression(polars::prelude::first())
    }

    /// Builds the polars expression `count()`.
    pub fn count() -> DataFrameExpression {
        DataFrameExpression(polars::prelude::count())
    }

    /// Builds a frame from a script map of `column name -> array`.
    ///
    /// Map entries whose value is not an array are skipped. A map that yields no columns
    /// produces an empty frame.
    ///
    /// # Errors
    /// Returns the polars error message when the columns cannot form a frame. Previously that
    /// failure was swallowed and an empty frame was returned instead.
    pub fn dataframe(mp: rhai::Map) -> ScriptResult<DataFrame> {
        let mut columns = Vec::new();
        for (key, value) in mp.into_iter() {
            if let Some(array) = value.try_cast::<rhai::Array>() {
                columns.push(series(&key, array)?.0);
            }
        }
        if columns.is_empty() {
            return Ok(DataFrame(polars::frame::DataFrame::default()));
        }
        let frame = polars::frame::DataFrame::new(columns).map_err(|e| e.to_string())?;
        Ok(DataFrame(frame))
    }

    /// `series(array)`: builds a series called `"Unnamed"`.
    pub fn series_unnamed(arr: rhai::Array) -> std::result::Result<Series, Box<EvalAltResult>> {
        series("Unnamed", arr)
    }

    /// `array.to_series()`: builds a series called `"Unnamed"`.
    pub fn to_series_unnamed(arr: rhai::Array) -> std::result::Result<Series, Box<EvalAltResult>> {
        series("Unnamed", arr)
    }

    /// `array.to_series(name)`: same as [`series`] with the arguments swapped.
    pub fn to_series(
        arr: rhai::Array,
        name: &str,
    ) -> std::result::Result<Series, Box<EvalAltResult>> {
        series(name, arr)
    }

    /// Builds a series called `name` from a script array.
    ///
    /// The element type is chosen from the first element: `i64`, `f64`, or otherwise `String`.
    /// Elements that cannot be cast to the chosen type are dropped, so the result can be
    /// shorter than `arr`. An empty array produces an empty `i64` series.
    pub fn series(name: &str, arr: rhai::Array) -> std::result::Result<Series, Box<EvalAltResult>> {
        // Capture the first element's type up front so `arr` can be consumed below.
        let first_type = arr.first().map(|value| value.type_id());
        let inner = match first_type {
            None => polars::series::Series::new(name, Vec::<i64>::new()),
            Some(t) if t == TypeId::of::<i64>() => polars::series::Series::new(
                name,
                arr.into_iter()
                    .filter_map(|value| value.try_cast::<i64>())
                    .collect::<Vec<i64>>(),
            ),
            Some(t) if t == TypeId::of::<f64>() => polars::series::Series::new(
                name,
                arr.into_iter()
                    .filter_map(|value| value.try_cast::<f64>())
                    .collect::<Vec<f64>>(),
            ),
            Some(_) => polars::series::Series::new(
                name,
                arr.into_iter()
                    .filter_map(|value| value.try_cast::<String>())
                    .collect::<Vec<String>>(),
            ),
        };
        Ok(Series(inner))
    }

    /// Function behind the `gt` custom operator: builds `col(a) > b`.
    pub fn gt_op(a: &str, b: rhai::Dynamic) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::col(a).gt(DataFrameExpression::from(b)))
    }

    /// Function behind the `gte` custom operator: builds `col(a) >= b`.
    pub fn gte_op(a: &str, b: rhai::Dynamic) -> DataFrameExpression {
        DataFrameExpression(polars::prelude::col(a).gt_eq(DataFrameExpression::from(b)))
    }
}
/// Implements the `from <dataframe> <select> : <filter>` custom syntax.
///
/// * `inputs[0]` names a scope variable that must hold a [`DataFrame`].
/// * `inputs[1]` must evaluate to an array of selections: column-name strings or
///   [`DataFrameExpression`]s. Entries of any other type are skipped.
/// * `inputs[2]` must evaluate to an array of [`DataFrameExpression`] filters.
///
/// The filter array is evaluated before the select array. Every filter is applied to every
/// selected expression, and the selection is collected into a new [`DataFrame`].
///
/// # Errors
/// Returns an error when the variable is missing or is not a dataframe, when either
/// expression does not evaluate to an array, when a filter entry is not a
/// [`DataFrameExpression`] (this used to panic), or when polars fails to collect the query.
fn implementation_df_select(
    context: &mut EvalContext,
    inputs: &[Expression],
) -> Result<Dynamic, Box<EvalAltResult>> {
    let df_name = inputs[0].get_string_value().ok_or_else(|| {
        Box::new(EvalAltResult::ErrorVariableNotFound(
            "variable not found".to_string(),
            Position::default(),
        ))
    })?;

    let df = context
        .scope()
        .get(df_name)
        .ok_or_else(|| {
            Box::new(EvalAltResult::ErrorVariableNotFound(
                format!("{} not found", df_name),
                Position::default(),
            ))
        })?
        .clone()
        .try_cast::<DataFrame>()
        .ok_or_else(|| {
            // The variable exists here, so say what is actually wrong with it.
            Box::new(EvalAltResult::ErrorVariableNotFound(
                format!("{} is not a dataframe", df_name),
                Position::default(),
            ))
        })?;

    let raw_filter_array = context.eval_expression_tree(&inputs[2])?;
    let filters = raw_filter_array
        .into_array()
        .map_err(|e| {
            Box::new(EvalAltResult::ErrorVariableNotFound(
                format!("{} value not an array", e),
                Position::default(),
            ))
        })?
        .into_iter()
        .map(|item| {
            // Read the type name first: `try_cast` consumes the value.
            let type_name = item.type_name();
            item.try_cast::<DataFrameExpression>().ok_or_else(|| {
                Box::<EvalAltResult>::from(format!(
                    "filter entry of type {} is not a dataframe expression",
                    type_name
                ))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    let raw_select_array = context.eval_expression_tree(&inputs[1])?;
    let select_array = raw_select_array.into_array().map_err(|e| {
        Box::new(EvalAltResult::ErrorVariableNotFound(
            format!("{} value not an array", e),
            Position::default(),
        ))
    })?;

    let select_expressions = select_array
        .into_iter()
        .filter_map(|entry| {
            let base = if entry.is::<String>() {
                polars::prelude::col(&entry.to_string())
            } else {
                // Anything that is neither a string nor an expression is skipped.
                entry.try_cast::<DataFrameExpression>()?.0
            };
            Some(
                filters
                    .iter()
                    .fold(base, |acc, filter| acc.filter(filter.0.clone())),
            )
        })
        .collect::<Vec<_>>();

    Ok(Dynamic::from(DataFrame(
        df.0.lazy()
            .select(&select_expressions)
            .collect()
            .map_err(|e| e.to_string())?,
    )))
}
/// End-to-end tests that run small scripts through the engine's `process` test helper and
/// inspect the resulting frame, series or scalar.
///
/// The CSV-based expectations assume `test/data.csv` has an `age` column holding 18, 22 and 32.
#[cfg(test)]
mod tests {
    use crate::engine::tests::process;

    /// A map of float, integer and string arrays becomes a frame with matching typed columns.
    #[test]
    pub fn simple_dataframe() {
        let frame = process(r#"dataframe(#{ floats: [1.0, 2.0, 3.0], ints: [1,2,3], strings: ["one", "two", "three"] }) "#).into_frame();
        let s = frame
            .column("floats")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(1.), Some(2.), Some(3.)]);
        let s = frame
            .column("ints")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(1), Some(2), Some(3)]);
        let s = frame
            .column("strings")
            .unwrap()
            .utf8()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some("one"), Some("two"), Some("three")]);
    }

    /// `.column(name)` on a frame returns that column as a series.
    #[test]
    pub fn dataframe_get_column() {
        let series =
            process(r#"dataframe(#{ item: [1.0, 2.0, 3.0] }).column("item") "#).into_series();
        let s = series.f64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(1.), Some(2.), Some(3.)]);
    }

    /// `series(name, array)` builds a float series from float literals.
    #[test]
    pub fn simple_series() {
        let series = process(r#"series("ages;", [18.,21.,25.,35.])"#).into_series();
        let s = series.f64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(18.), Some(21.), Some(25.), Some(35.)]);
    }

    /// `head(1)` keeps only the first element.
    #[test]
    pub fn series_head() {
        let series = process(r#"series("ages", [18.,21.,25.,35.]).head(1)"#).into_series();
        let s = series.f64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(18.)]);
    }

    /// `sort(true)` orders descending, so the head is the largest value.
    #[test]
    pub fn series_sort() {
        let series =
            process(r#"series("ages", [18.,21.,25.,35.]).sort(true).head(1)"#).into_series();
        let s = series.f64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(35.)]);
    }

    /// Indexing a series yields a script scalar.
    #[test]
    pub fn series_index() {
        let s = process(r#"series("ages", [18, 21, 25, 35]).sort(true)[1]"#).into_scalar();
        assert_eq!(s.cast::<i64>(), 25);
    }

    /// `+` between two integer series adds them element by element.
    #[test]
    pub fn series_add() {
        let series = process(
            r#"
let s1 = series("ages", [1, 1, 1, 1]);
let s2 = series("ages", [18, 21, 25, 35]);
s1 + s2
"#,
        )
        .into_series();
        let s = series.i64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(19), Some(22), Some(26), Some(36)]);
    }

    /// A column loaded from CSV can be sorted and truncated like any other series.
    #[test]
    pub fn test_load_csv() {
        let series =
            process(r#"load_csv("test/data.csv").column("age").sort(true).head(1)"#).into_series();
        let s = series.i64().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(s, vec![Some(32)]);
    }

    /// `select` with the `sum`, `min`, `max` and `mean` aggregate expressions.
    #[test]
    pub fn test_dataframe_select() {
        let res = process(r#" let data = load_csv("test/data.csv"); data.select([sum("age")]) "#)
            .into_frame();
        let s = res
            .column("age")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(72)]);
        let res = process(r#" let data = load_csv("test/data.csv"); data.select([min("age")]) "#)
            .into_frame();
        let s = res
            .column("age")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(18)]);
        let res = process(r#" let data = load_csv("test/data.csv"); data.select([max("age")]) "#)
            .into_frame();
        let s = res
            .column("age")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(32)]);
        let res = process(r#" let data = load_csv("test/data.csv"); data.select([mean("age")]) "#)
            .into_frame();
        let s = res
            .column("age")
            .unwrap()
            .f64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(24.0)]);
    }

    /// A filtered column expression can be aggregated inside `select`.
    #[test]
    pub fn test_dataframe_filter() {
        let res = process(r#" let data = load_csv("test/data.csv"); data.select([column("age").filter(column("age").eq(18)).sum()]) "#)
            .into_frame();
        let s = res
            .column("age")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(18)]);
    }

    /// The `from <frame> [select] : [filter]` custom syntax together with the `gt` operator.
    #[test]
    pub fn test_dataframe_select_syntax() {
        let res = process(
            r#"
let data = load_csv("test/data.csv");
from data ["age"] : ["age" gt 18];
"#,
        );
        let s = res
            .into_frame()
            .column("age")
            .unwrap()
            .i64()
            .unwrap()
            .into_iter()
            .collect::<Vec<_>>();
        assert_eq!(s, vec![Some(22), Some(32)]);
    }
}