use super::ScriptResult; use polars::prelude::{IntoLazy, NamedFrom, SerReader}; use std::{ any::TypeId, ops::{Add, Deref}, str::FromStr, }; use rhai::{Dynamic, EvalAltResult, EvalContext, Expression, Position}; pub fn setup_engine(engine: &mut rhai::Engine) { //polar data frame engine.register_type::(); engine.register_fn("dataframe", script_functions::dataframe); engine.register_fn("select", DataFrame::s_select); engine.register_fn("load_csv", DataFrame::load_csv); engine.register_fn("column", DataFrame::s_column); engine.register_indexer_get(DataFrame::s_column); engine.register_indexer_set(DataFrame::append_column); // polar series engine.register_type::(); engine.register_indexer_get(Series::s_get); engine.register_fn("series", script_functions::series); engine.register_fn("to_series", script_functions::to_series); engine.register_fn("to_series", script_functions::to_series_unnamed); engine.register_fn("series", script_functions::series_unnamed); engine.register_fn("head", Series::s_head); engine.register_fn("sort", Series::s_sort); engine.register_fn("sum", Series::s_sum); engine.register_fn("add", Series::s_op_add); engine.register_fn("+", Series::add); engine.register_fn("+", script_functions::add_series_i64); engine.register_fn("+", script_functions::add_series_f64); engine.register_fn("-", script_functions::subtract_series_series); engine.register_fn("-", script_functions::subtract_series_i64); engine.register_fn("-", script_functions::subtract_series_f64); engine.register_fn("**", script_functions::power_series_f64); engine.register_fn("**", script_functions::power_series_i64); engine.register_fn("*", script_functions::multiply_series_series); engine.register_fn("*", script_functions::multiply_series_i64); engine.register_fn("*", script_functions::multiply_series_f64); engine.register_fn("/", script_functions::div_series_series); engine.register_fn("/", script_functions::div_series_i64); engine.register_fn("/", script_functions::div_series_f64); // polar expressions engine.register_type::(); engine.register_fn("sum", DataFrameExpression::sum); engine.register_fn("mean", DataFrameExpression::mean); engine.register_fn("min", DataFrameExpression::min); engine.register_fn("max", DataFrameExpression::max); engine.register_fn("first", DataFrameExpression::first); engine.register_fn("count", DataFrameExpression::count); engine.register_fn("eq", DataFrameExpression::eq); engine.register_fn("lt", DataFrameExpression::lt); engine.register_fn("lte", DataFrameExpression::lte); engine.register_fn("gt", DataFrameExpression::gt); engine.register_fn("gte", DataFrameExpression::gte); engine.register_fn("filter", DataFrameExpression::filter); engine.register_fn("column", script_functions::column); engine.register_fn("count", script_functions::count); engine.register_fn("sum", script_functions::sum); engine.register_fn("mean", script_functions::mean); engine.register_fn("min", script_functions::min); engine.register_fn("max", script_functions::max); engine.register_fn("first", script_functions::first); let _ = engine.register_custom_operator("gt", 200); let _ = engine.register_custom_operator("gte", 200); let _ = engine.register_custom_operator("<<", 200); engine.register_fn("gt", script_functions::gt_op); engine.register_fn("gte", script_functions::gte_op); engine .register_custom_syntax( ["from", "$ident$", "$expr$", ":", "$expr$"], // the custom syntax true, // variables declared within this custom syntax implementation_df_select, ) .unwrap(); } #[derive(Clone, Debug, PartialEq)] pub struct DataFrame(polars::frame::DataFrame); impl Deref for DataFrame { type Target = polars::frame::DataFrame; fn deref(&self) -> &Self::Target { &self.0 } } impl DataFrame { pub fn s_column(&mut self, name: &str) -> ScriptResult { Ok(Series( self.0.column(name).map_err(|e| e.to_string())?.clone(), )) } pub fn load_csv(path: &str) -> ScriptResult { let df = polars::io::csv::CsvReader::from_path(path) .map_err(|e| e.to_string())? .infer_schema(Some(1)) .with_ignore_parser_errors(true) .has_header(true) .finish() .map_err(|e| e.to_string())?; Ok(DataFrame(df)) } pub fn s_select(&mut self, expressions: rhai::Array) -> ScriptResult { let expressions = expressions .into_iter() .filter_map(|i| i.try_cast::()) .map(polars::lazy::dsl::Expr::from) .collect::>(); Ok(DataFrame( self.0 .clone() .lazy() .select(expressions) .collect() .map_err(|e| e.to_string())?, )) } pub fn append_column(&mut self, idx: &str, mut b: Series) -> ScriptResult<()> { b.0.rename(idx); self.0.with_column(b.0).map_err(|e| e.to_string())?; Ok(()) } } #[derive(Clone, Debug)] pub struct DataFrameExpression(polars::lazy::dsl::Expr); impl From for polars::lazy::dsl::Expr { fn from(source: DataFrameExpression) -> Self { source.0 } } impl DataFrameExpression { fn first(&mut self) -> Self { DataFrameExpression(self.0.clone().first()) } fn sum(&mut self) -> Self { DataFrameExpression(self.0.clone().sum()) } fn mean(&mut self) -> Self { DataFrameExpression(self.0.clone().mean()) } fn max(&mut self) -> Self { DataFrameExpression(self.0.clone().max()) } fn min(&mut self) -> Self { DataFrameExpression(self.0.clone().min()) } fn count(&mut self) -> Self { DataFrameExpression(self.0.clone().count()) } fn eq(&mut self, e: rhai::Dynamic) -> Self { DataFrameExpression(self.0.clone().eq(DataFrameExpression::from(e))) } fn gt(&mut self, e: rhai::Dynamic) -> Self { DataFrameExpression(self.0.clone().gt(DataFrameExpression::from(e))) } fn lt(&mut self, e: rhai::Dynamic) -> Self { DataFrameExpression(self.0.clone().lt(DataFrameExpression::from(e))) } fn gte(&mut self, e: rhai::Dynamic) -> Self { DataFrameExpression(self.0.clone().gt_eq(DataFrameExpression::from(e))) } fn lte(&mut self, e: rhai::Dynamic) -> Self { DataFrameExpression(self.0.clone().lt_eq(DataFrameExpression::from(e))) } pub fn filter(&mut self, expr: DataFrameExpression) -> DataFrameExpression { DataFrameExpression(self.0.clone().filter(expr.0)) } } impl From for DataFrameExpression { fn from(source: rhai::Dynamic) -> Self { match source { x if x.is::() => DataFrameExpression(x.as_int().unwrap().into()), x if x.is::() => DataFrameExpression(x.as_float().unwrap().into()), x if x.is::() => DataFrameExpression(x.into_string().unwrap()[..].into()), x if x.is::() => DataFrameExpression(x.as_bool().unwrap().into()), _ => DataFrameExpression(polars::lazy::dsl::Expr::default()), } } } #[derive(Clone, Debug, PartialEq)] pub struct Series(polars::series::Series); impl Deref for Series { type Target = polars::series::Series; fn deref(&self) -> &Self::Target { &self.0 } } impl Add for Series { type Output = Series; fn add(self, rhs: Self) -> Self::Output { Series(self.0 + rhs.0) } } impl Series { pub fn s_head(&mut self, n: i64) -> Series { Series(self.0.head(Some(n as usize))) } pub fn s_sort(&mut self, reverse: bool) -> Series { Series(self.0.sort(reverse)) } pub fn s_sum(&mut self) -> i64 { self.0.sum().unwrap_or_default() } pub fn s_get(&mut self, n: i64) -> ScriptResult { let value = self.get(n as usize); match value { polars::datatypes::AnyValue::Utf8(v) => { Ok(rhai::Dynamic::from_str(v).unwrap_or_default()) } polars::prelude::AnyValue::Null => Ok(rhai::Dynamic::UNIT), polars::prelude::AnyValue::Boolean(v) => Ok(rhai::Dynamic::from_bool(v)), polars::prelude::AnyValue::UInt8(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::UInt16(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::UInt32(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::UInt64(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Int8(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Int16(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Int32(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Int64(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Float32(v) => Ok(rhai::Dynamic::from_float(v as f64)), polars::prelude::AnyValue::Float64(v) => Ok(rhai::Dynamic::from_float(v as f64)), polars::prelude::AnyValue::Date(v) => Ok(rhai::Dynamic::from_int(v as i64)), polars::prelude::AnyValue::Datetime(_, _, _) => Ok(rhai::Dynamic::from_int(0)), polars::prelude::AnyValue::Duration(_, _) => Ok(rhai::Dynamic::from_int(0)), polars::prelude::AnyValue::Time(_) => Ok(rhai::Dynamic::from_int(0)), polars::prelude::AnyValue::List(s) => Ok(rhai::Dynamic::from(Series(s))), polars::prelude::AnyValue::Utf8Owned(v) => { Ok(rhai::Dynamic::from_str(&v).unwrap_or_default()) } } } pub fn s_op_add(self, series: Series) -> Series { Series(self.0 + series.0) } } mod script_functions { use super::*; pub fn add_series_i64(a: Series, b: i64) -> Series { Series(a.0 + b) } pub fn add_series_f64(a: Series, b: f64) -> Series { Series(a.0 + b) } pub fn subtract_series_i64(a: Series, b: i64) -> Series { Series(a.0 - b) } pub fn subtract_series_f64(a: Series, b: f64) -> Series { Series(a.0 - b) } pub fn subtract_series_series(a: Series, b: Series) -> Series { Series(a.0 - b.0) } pub fn power_series_i64(a: Series, b: i64) -> ScriptResult { let name = a.name(); let a_series = a.0.clone(); let df = a_series .into_frame() .lazy() .select([polars::prelude::col(name).pow(b)]) .collect() .map_err(|e| e.to_string())?; let s = df.column(name).map_err(|e| e.to_string())?; Ok(Series(s.clone())) } pub fn power_series_f64(a: Series, b: f64) -> ScriptResult { let name = a.name(); let a_series = a.0.clone(); let df = a_series .into_frame() .lazy() .select([polars::prelude::col(name).pow(b)]) .collect() .map_err(|e| e.to_string())?; let s = df.column(name).map_err(|e| e.to_string())?; Ok(Series(s.clone())) } pub fn multiply_series_i64(a: Series, b: i64) -> Series { Series(a.0 * b) } pub fn multiply_series_f64(a: Series, b: f64) -> Series { Series(a.0 * b) } pub fn multiply_series_series(a: Series, b: Series) -> Series { Series(a.0 * b.0) } pub fn div_series_i64(a: Series, b: i64) -> Series { Series(a.0 / b) } pub fn div_series_f64(a: Series, b: f64) -> Series { Series(a.0 / b) } pub fn div_series_series(a: Series, b: Series) -> Series { Series(a.0 / b.0) } pub fn column(name: &str) -> DataFrameExpression { DataFrameExpression(polars::prelude::col(name)) } pub fn sum(name: &str) -> DataFrameExpression { DataFrameExpression(polars::prelude::sum(name)) } pub fn min(name: &str) -> DataFrameExpression { DataFrameExpression(polars::prelude::min(name)) } pub fn max(name: &str) -> DataFrameExpression { DataFrameExpression(polars::prelude::max(name)) } pub fn mean(name: &str) -> DataFrameExpression { DataFrameExpression(polars::prelude::mean(name)) } pub fn first() -> DataFrameExpression { DataFrameExpression(polars::prelude::first()) } pub fn count() -> DataFrameExpression { DataFrameExpression(polars::prelude::count()) } pub fn dataframe(mp: rhai::Map) -> ScriptResult { let mut series_vec = vec![]; for (key, value) in mp.into_iter() { if let Some(array) = value.try_cast::() { let s = series(&key, array)?; series_vec.push(s); } } if let Ok(df) = polars::frame::DataFrame::new(series_vec.into_iter().map(|i| i.0).collect()) { Ok(DataFrame(df)) } else { Ok(DataFrame(polars::frame::DataFrame::default())) } } pub fn series_unnamed(arr: rhai::Array) -> std::result::Result> { series("Unnamed", arr) } pub fn to_series_unnamed(arr: rhai::Array) -> std::result::Result> { series("Unnamed", arr) } pub fn to_series( arr: rhai::Array, name: &str, ) -> std::result::Result> { series(name, arr) } pub fn series(name: &str, arr: rhai::Array) -> std::result::Result> { if let Some(i) = arr.first() { let series = if i.type_id() == TypeId::of::() { polars::series::Series::new( name, arr.into_iter() .filter_map(|i| i.try_cast::()) .collect::>(), ) } else if i.type_id() == TypeId::of::() { polars::series::Series::new( name, arr.into_iter() .filter_map(|i| i.try_cast::()) .collect::>(), ) } else { polars::series::Series::new( name, arr.into_iter() .filter_map(|i| i.try_cast::()) .collect::>(), ) }; Ok(Series(series)) } else { let v: Vec = vec![]; let s = polars::series::Series::new(name, v); Ok(Series(s)) } } pub fn gt_op(a: &str, b: rhai::Dynamic) -> DataFrameExpression { DataFrameExpression(polars::prelude::col(a).gt(DataFrameExpression::from(b))) } pub fn gte_op(a: &str, b: rhai::Dynamic) -> DataFrameExpression { DataFrameExpression(polars::prelude::col(a).gt_eq(DataFrameExpression::from(b))) } } fn implementation_df_select( context: &mut EvalContext, inputs: &[Expression], ) -> Result> { let df_name = inputs[0].get_string_value().ok_or_else(|| { Box::new(EvalAltResult::ErrorVariableNotFound( "variable not found".to_string(), Position::default(), )) })?; let df = context .scope() .get(df_name) .ok_or_else(|| { Box::new(EvalAltResult::ErrorVariableNotFound( format!("{} not found", df_name), Position::default(), )) })? .clone() .try_cast::() .ok_or_else(|| { Box::new(EvalAltResult::ErrorVariableNotFound( format!("{} not found", df_name), Position::default(), )) })?; let raw_filter_array = context.eval_expression_tree(&inputs[2])?; let filter_array = raw_filter_array .into_array() .map_err(|e| { Box::new(EvalAltResult::ErrorVariableNotFound( format!("{} value not an array", e), Position::default(), )) })? .into_iter() .map(|i| i.cast::()) .collect::>(); let raw_select_array = context.eval_expression_tree(&inputs[1])?; let select_array = raw_select_array.into_array().map_err(|e| { Box::new(EvalAltResult::ErrorVariableNotFound( format!("{} value not an array", e), Position::default(), )) })?; let select_expressions = select_array .iter() .filter_map(|s| { let select_expr = if s.is::() { polars::prelude::col(&s.to_string()) } else if s.is::() { s.clone().cast::().0 } else { return None; }; Some( filter_array .iter() .fold(select_expr, |acc, i| acc.filter(i.0.clone())), ) }) .collect::>(); Ok(Dynamic::from(DataFrame( df.0.lazy() .select(&select_expressions) .collect() .map_err(|e| e.to_string())?, ))) } #[cfg(test)] mod tests { use crate::engine::tests::process; #[test] pub fn simple_dataframe() { let frame = process(r#"dataframe(#{ floats: [1.0, 2.0, 3.0], ints: [1,2,3], strings: ["one", "two", "three"] }) "#).into_frame(); let s = frame .column("floats") .unwrap() .f64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(1.), Some(2.), Some(3.)]); let s = frame .column("ints") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(1), Some(2), Some(3)]); let s = frame .column("strings") .unwrap() .utf8() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some("one"), Some("two"), Some("three")]); } #[test] pub fn dataframe_get_column() { let series = process(r#"dataframe(#{ item: [1.0, 2.0, 3.0] }).column("item") "#).into_series(); let s = series.f64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(1.), Some(2.), Some(3.)]); } #[test] pub fn simple_series() { let series = process(r#"series("ages;", [18.,21.,25.,35.])"#).into_series(); let s = series.f64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(18.), Some(21.), Some(25.), Some(35.)]); } #[test] pub fn series_head() { let series = process(r#"series("ages", [18.,21.,25.,35.]).head(1)"#).into_series(); let s = series.f64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(18.)]); } #[test] pub fn series_sort() { let series = process(r#"series("ages", [18.,21.,25.,35.]).sort(true).head(1)"#).into_series(); let s = series.f64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(35.)]); } #[test] pub fn series_index() { let s = process(r#"series("ages", [18, 21, 25, 35]).sort(true)[1]"#).into_scalar(); assert_eq!(s.cast::(), 25); } #[test] pub fn series_add() { let series = process( r#" let s1 = series("ages", [1, 1, 1, 1]); let s2 = series("ages", [18, 21, 25, 35]); s1 + s2 "#, ) .into_series(); let s = series.i64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(19), Some(22), Some(26), Some(36)]); } #[test] pub fn test_load_csv() { let series = process(r#"load_csv("test/data.csv").column("age").sort(true).head(1)"#).into_series(); let s = series.i64().unwrap().into_iter().collect::>(); assert_eq!(s, vec![Some(32)]); } #[test] pub fn test_dataframe_select() { let res = process(r#" let data = load_csv("test/data.csv"); data.select([sum("age")]) "#) .into_frame(); let s = res .column("age") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(72)]); let res = process(r#" let data = load_csv("test/data.csv"); data.select([min("age")]) "#) .into_frame(); let s = res .column("age") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(18)]); let res = process(r#" let data = load_csv("test/data.csv"); data.select([max("age")]) "#) .into_frame(); let s = res .column("age") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(32)]); let res = process(r#" let data = load_csv("test/data.csv"); data.select([mean("age")]) "#) .into_frame(); let s = res .column("age") .unwrap() .f64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(24.0)]); } #[test] pub fn test_dataframe_filter() { let res = process(r#" let data = load_csv("test/data.csv"); data.select([column("age").filter(column("age").eq(18)).sum()]) "#) .into_frame(); let s = res .column("age") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(18)]); } #[test] pub fn test_dataframe_select_syntax() { let res = process( r#" let data = load_csv("test/data.csv"); from data ["age"] : ["age" gt 18]; "#, ); dbg!(&res); let s = res .into_frame() .column("age") .unwrap() .i64() .unwrap() .into_iter() .collect::>(); assert_eq!(s, vec![Some(22), Some(32)]); } }