diff --git a/Cargo.lock b/Cargo.lock index 095d7095..46fa81c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,6 +126,15 @@ version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +[[package]] +name = "ar_archive_writer" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c269894b6fe5e9d7ada0cf69b5bf847ff35bc25fc271f08e1d080fce80339a" +dependencies = [ + "object 0.32.2", +] + [[package]] name = "array-init" version = "2.1.0" @@ -274,7 +283,7 @@ dependencies = [ "cfg-if", "libc", "miniz_oxide", - "object", + "object 0.36.5", "rustc-demangle", "windows-targets 0.52.6", ] @@ -445,10 +454,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.1" +version = "1.2.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" dependencies = [ + "find-msvc-tools", "shlex", ] @@ -1026,6 +1036,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" + [[package]] name = "fixedbitset" version = "0.4.2" @@ -1833,6 +1849,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "object" version = "0.36.5" @@ -2219,6 +2244,16 @@ dependencies = [ "prost", ] +[[package]] +name = "psm" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11f2fedc3b7dafdc2851bc52f277377c5473d378859be234bc7ebb593144d01" +dependencies = [ + "ar_archive_writer", + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -2403,6 +2438,26 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn 2.0.89", +] + [[package]] name = "redox_syscall" version = "0.5.7" @@ -2934,11 +2989,12 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.53.0" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" +checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" dependencies = [ "log", + "recursive", "serde", ] @@ -2960,6 +3016,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "stacker" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1f8b29fb42aafcea4edeeb6b2f2d7ecd0d969c48b4cf0d2e64aafc471dd6e59" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + [[package]] name = "stringprep" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index df134515..fdce689d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" smallvec = { version = "1", features = ["serde"] } sqllogictest = "0.23" -sqlparser = { version = "0.53", features = ["serde"] } +sqlparser = { version = "0.59.0", features = ["serde"] } thiserror = "2" tikv-jemallocator = { version = "0.6", optional = true, features = [ "disable_initial_exec_tls", diff --git a/src/binder/create_function.rs b/src/binder/create_function.rs index edce65b6..09b7208c 100644 --- a/src/binder/create_function.rs +++ b/src/binder/create_function.rs @@ -88,8 +88,14 @@ impl Binder { let body = match function_body { Some(CreateFunctionBody::AsBeforeOptions(expr)) | Some(CreateFunctionBody::AsAfterOptions(expr)) => match expr { - Expr::Value(Value::SingleQuotedString(s)) => s, - Expr::Value(Value::DollarQuotedString(s)) => s.value, + Expr::Value(vlaue_with_span) => match vlaue_with_span.value { + Value::SingleQuotedString(s) => s, + Value::DollarQuotedString(s) => s.value, + _ => { + return Err(ErrorKind::BindFunctionError("expected string".into()) + .with_spanned(&vlaue_with_span)); + } + }, _ => { return Err( ErrorKind::BindFunctionError("expected string".into()).with_spanned(&expr) @@ -101,7 +107,7 @@ impl Binder { // will NOT involve complex syntax, so just reuse the logic for select definition format!("select {}", &return_expr.to_string()) } - None => { + _ => { return Err(ErrorKind::BindFunctionError( "AS or RETURN must be specified".to_string(), ) diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 83d52cf4..42bca25e 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -80,7 +80,11 @@ impl FromStr for Box { } impl Binder { - fn parse_index_type(&self, using: Option, with: Vec) -> Result { + fn parse_index_type( + &self, + using: Option, + with: Vec, + ) -> Result { let Some(using) = using else { return Err(ErrorKind::InvalidIndex("using clause is required".to_string()).into()); }; @@ -113,7 +117,7 @@ impl Binder { ErrorKind::InvalidIndex("invalid with clause".to_string()).into() ); }; - let v: DataValue = v.into(); + let v: DataValue = v.value.into(); match key.as_str() { "distfn" => { let v = v.as_str(); @@ -183,7 +187,7 @@ impl Binder { let mut column_ids = Vec::new(); for column in &columns { // Ensure column expr is a column reference - let OrderByExpr { expr, .. } = column; + let OrderByExpr { expr, .. } = &column.column; let Expr::Identifier(column_name) = expr else { return Err( ErrorKind::InvalidColumn("column reference expected".to_string()) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 04f91ac3..eea7b7be 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -85,7 +85,7 @@ impl Binder { return Err(ErrorKind::NotSupportedTSQL.into()); } - let pks_name_from_constraints = Binder::pks_name_from_constraints(&constraints); + let pks_name_from_constraints = Binder::pks_name_from_constraints(&constraints)?; if has_pk_from_column && !pks_name_from_constraints.is_empty() { // can't get primary key both from "primary key(c1, c2...)" syntax and // column's option @@ -157,15 +157,26 @@ impl Binder { } /// get the primary keys' name sorted by declaration order in "primary key(c1, c2..)" syntax. - fn pks_name_from_constraints(constraints: &[TableConstraint]) -> &[Ident] { + fn pks_name_from_constraints(constraints: &[TableConstraint]) -> Result<&[Ident]> { for constraint in constraints { match constraint { - TableConstraint::PrimaryKey { columns, .. } => return columns, + TableConstraint::PrimaryKey { columns, .. } => { + let _ = columns.iter().map(|column_index| { + Ok(match &column_index.column.expr { + Expr::Identifier(ident) => ident, + _ => { + return Err(ErrorKind::InvalidIndex( + column_index.column.expr.to_string(), + )); + } + }) + }); + } _ => continue, } } // no primary key - &[] + Ok(&[]) } } diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 605359d7..9f93e48c 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -21,13 +21,13 @@ impl Binder { // parameter-like (i.e., `$1`) values at present // TODO: consider formally `bind_parameter` in the future // e.g., lambda function support, etc. - if let Value::Placeholder(key) = &v { + if let Value::Placeholder(key) = &v.value { self.udf_context .get_expr(key) .cloned() .ok_or_else(|| ErrorKind::InvalidSQL.with_spanned(&v)) } else { - Ok(self.egraph.add(Node::Constant(v.into()))) + Ok(self.egraph.add(Node::Constant(v.value.into()))) } } Expr::Identifier(ident) => self.bind_ident([ident]), @@ -44,7 +44,9 @@ impl Binder { let isnull = self.bind_is_null(*expr)?; Ok(self.egraph.add(Node::Not(isnull))) } - Expr::TypedString { data_type, value } => self.bind_typed_string(data_type, value), + Expr::TypedString(typed_string) => { + self.bind_typed_string(typed_string.data_type, typed_string.value) + } Expr::Like { negated, expr, @@ -68,9 +70,9 @@ impl Binder { Expr::Case { operand, conditions, - results, else_result, - } => self.bind_case(operand, conditions, results, else_result), + .. + } => self.bind_case(operand, conditions, else_result), Expr::InList { expr, list, @@ -143,8 +145,8 @@ impl Binder { Or => Node::Or([l, r]), Xor => Node::Xor([l, r]), Spaceship => Node::VectorCosineDistance([l, r]), + LtDashGt => Node::VectorL2Distance([l, r]), Custom(name) => match name.as_str() { - "<->" => Node::VectorL2Distance([l, r]), "<#>" => Node::VectorNegtiveInnerProduct([l, r]), op => todo!("bind custom binary op: {:?}", op), }, @@ -169,7 +171,7 @@ impl Binder { // workaround for 'BLOB' if let DataType::Custom(name, _modifiers) = &ty && name.0.len() == 1 - && name.0[0].value.to_lowercase() == "blob" + && derive_ident(&name.0[0]).value.to_lowercase() == "blob" { ty = DataType::Blob(None); } @@ -182,7 +184,11 @@ impl Binder { Ok(self.egraph.add(Node::IsNull(expr))) } - fn bind_typed_string(&mut self, data_type: DataType, value: String) -> Result { + fn bind_typed_string(&mut self, data_type: DataType, value_with_span: ValueWithSpan) -> Result { + let ValueWithSpan { value, span } = value_with_span; + let value = value.into_string().ok_or_else(|| { + ErrorKind::InvalidExpression("must be string".to_string()).with_span(span) + })?; match data_type { DataType::Date => { let date = value.parse().map_err(|_| { @@ -234,9 +240,18 @@ impl Binder { } fn bind_interval(&mut self, interval: parser::Interval) -> Result { - let Expr::Value(Value::Number(v, _) | Value::SingleQuotedString(v)) = *interval.value - else { - panic!("interval value must be number or string"); + let Expr::Value(value_with_span) = *interval.value else { + return Err( + ErrorKind::InvalidExpression("interval value must be value".to_string()).into(), + ); + }; + + let ValueWithSpan { value, .. } = value_with_span; + + let v = match value { + Value::Number(n, _) => n, + Value::SingleQuotedString(s) => s, + _ => panic!("interval value must be number or string"), }; let num = v.parse().expect("interval value is not a number"); let value = DataValue::Interval(match interval.leading_field { @@ -257,8 +272,7 @@ impl Binder { fn bind_case( &mut self, operand: Option>, - conditions: Vec, - results: Vec, + whens: Vec, else_result: Option>, ) -> Result { let operand = operand.map(|expr| self.bind_expr(*expr)).transpose()?; @@ -266,12 +280,12 @@ impl Binder { Some(expr) => self.bind_expr(*expr)?, None => self.egraph.add(Node::null()), }; - for (cond, result) in conditions.into_iter().rev().zip(results.into_iter().rev()) { - let mut cond = self.bind_expr(cond)?; + for CaseWhen { condition, result } in whens.iter().rev() { + let mut cond = self.bind_expr(condition.clone())?; if let Some(operand) = operand { cond = self.egraph.add(Node::Eq([operand, cond])); } - let mut result = self.bind_expr(result)?; + let mut result = self.bind_expr(result.clone())?; (result, case) = self.implicit_type_cast(result, case)?; case = self.egraph.add(Node::If([cond, result, case])); } diff --git a/src/binder/insert.rs b/src/binder/insert.rs index 29ba1064..bd7c8948 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -7,11 +7,17 @@ impl Binder { let Some(source) = insert.source else { return Err(ErrorKind::InvalidSQL.with_spanned(&insert)); }; - let (table, is_internal, is_view) = self.bind_table_id(&insert.table_name)?; + let table_name = match &insert.table { + TableObject::TableName(object_name) => object_name, + TableObject::TableFunction(_) => { + return Err(ErrorKind::Todo("Table Function".into()).with_spanned(&insert.table)); + } + }; + let (table, is_internal, is_view) = self.bind_table_id(table_name)?; if is_internal || is_view { - return Err(ErrorKind::CanNotInsert.with_spanned(&insert.table_name)); + return Err(ErrorKind::CanNotInsert.with_spanned(table_name)); } - let cols = self.bind_table_columns(&insert.table_name, &insert.columns)?; + let cols = self.bind_table_columns(table_name, &insert.columns)?; let source = self.bind_query(*source)?.0; let id = self.egraph.add(Node::Insert([table, cols, source])); Ok(id) diff --git a/src/binder/mod.rs b/src/binder/mod.rs index edb901df..ab6002d3 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -260,9 +260,12 @@ impl Binder { statement, analyze, .. } => self.bind_explain(*statement, analyze), Statement::Pragma { name, value, .. } => self.bind_pragma(name, value), - Statement::SetVariable { - variables, value, .. - } => self.bind_set(variables.as_ref(), value), + Statement::Set(set) => match set { + Set::SingleAssignment { + variable, values, .. + } => self.bind_set(&[variable], values), + _ => Err(ErrorKind::Todo("other set operators".into()).into()), + }, Statement::ShowVariable { .. } | Statement::ShowCreate { .. } | Statement::ShowColumns { .. } => Err(ErrorKind::NotSupportedTSQL.into()), @@ -426,18 +429,47 @@ impl Binder { /// Split an object name into `(schema name, table name)`. fn split_name(name: &ObjectName) -> Result<(&str, &str)> { Ok(match name.0.as_slice() { - [table] => (RootCatalog::DEFAULT_SCHEMA_NAME, &table.value), - [schema, table] => (&schema.value, &table.value), - _ => return Err(ErrorKind::InvalidTableName(name.0.clone()).with_spanned(name)), + [table] => (RootCatalog::DEFAULT_SCHEMA_NAME, &derive_ident(table).value), + [schema, table] => (&derive_ident(schema).value, &derive_ident(table).value), + _ => { + return Err(ErrorKind::InvalidTableName( + name.0.iter().map(|o| derive_ident(o).clone()).collect(), + ) + .with_spanned(name)); + } }) } +/// deriver Ident from object name part +fn derive_ident(object: &ObjectNamePart) -> &Ident { + match object { + ObjectNamePart::Identifier(ident) => ident, + ObjectNamePart::Function(table_function) => &table_function.name, + } +} + /// Convert an object name into lower case fn lower_case_name(name: &ObjectName) -> ObjectName { ObjectName( name.0 .iter() - .map(|ident| Ident::with_span(ident.span, ident.value.to_lowercase())) + .map(|obj_name_part| match obj_name_part { + ObjectNamePart::Identifier(ident) => ObjectNamePart::Identifier(Ident { + value: ident.value.to_lowercase(), + quote_style: ident.quote_style, + span: ident.span, + }), + ObjectNamePart::Function(object_name_part_function) => { + ObjectNamePart::Function(ObjectNamePartFunction { + name: Ident { + value: object_name_part_function.name.value.to_lowercase(), + quote_style: object_name_part_function.name.quote_style, + span: object_name_part_function.name.span, + }, + args: object_name_part_function.args.clone(), + }) + } + }) .collect::>(), ) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 4e2897a2..009e42c8 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -29,13 +29,31 @@ impl Binder { SetExpr::Values(values) => self.bind_values(values)?, body => return Err(ErrorKind::Todo("unknown set expr".into()).with_spanned(&body)), }; - let limit = match query.limit { - Some(expr) => self.bind_expr(expr)?, - None => self.egraph.add(Node::null()), - }; - let offset = match query.offset { - Some(offset) => self.bind_expr(offset.value)?, - None => self.egraph.add(Node::zero()), + let default_limit = self.egraph.add(Node::null()); + let default_offset = self.egraph.add(Node::zero()); + + let (limit, offset) = if let Some(clause) = query.limit_clause { + match clause { + LimitClause::LimitOffset { limit, offset, .. } => { + let limit_node = limit + .map(|l| self.bind_expr(l)) + .transpose()? + .unwrap_or(default_limit); + + let offset_node = offset + .map(|o| self.bind_expr(o.value)) + .transpose()? + .unwrap_or(default_offset); + + (limit_node, offset_node) + } + + LimitClause::OffsetCommaLimit { .. } => { + return Err(ErrorKind::Todo("MySQL Limit syntax".into()).with_spanned(&clause)); + } + } + } else { + (default_limit, default_offset) }; Ok(self.egraph.add(Node::Limit([limit, offset, child]))) } @@ -90,7 +108,12 @@ impl Binder { }; let having = self.bind_having(select.having)?; let orderby = match order_by { - Some(order_by) => self.bind_orderby(order_by.exprs)?, + Some(order_by) => match order_by.kind { + OrderByKind::All(_) => { + return Err(ErrorKind::Todo("order by all".into()).with_spanned(&order_by)); + } + OrderByKind::Expressions(exprs) => self.bind_orderby(exprs)?, + }, None => self.egraph.add(Node::List([].into())), }; let distinct = match select.distinct { @@ -195,7 +218,7 @@ impl Binder { let mut orderby = Vec::with_capacity(order_by.len()); for e in order_by { let expr = self.bind_expr(e.expr)?; - let key = match e.asc { + let key = match e.options.asc { Some(true) | None => expr, Some(false) => self.egraph.add(Node::Desc(expr)), }; diff --git a/src/binder/table.rs b/src/binder/table.rs index e01cf279..3a91f148 100644 --- a/src/binder/table.rs +++ b/src/binder/table.rs @@ -89,17 +89,17 @@ impl Binder { fn bind_join_op(&mut self, op: JoinOperator) -> Result<(Id, Id)> { use JoinOperator::*; match op { - Inner(constraint) => { + Join(constraint) | Inner(constraint) => { let ty = self.egraph.add(Node::Inner); let condition = self.bind_join_constraint(constraint)?; Ok((ty, condition)) } - LeftOuter(constraint) => { + Left(constraint) | LeftOuter(constraint) => { let ty = self.egraph.add(Node::LeftOuter); let condition = self.bind_join_constraint(constraint)?; Ok((ty, condition)) } - RightOuter(constraint) => { + Right(constraint) | RightOuter(constraint) => { let ty = self.egraph.add(Node::RightOuter); let condition = self.bind_join_constraint(constraint)?; Ok((ty, condition)) @@ -109,9 +109,9 @@ impl Binder { let condition = self.bind_join_constraint(constraint)?; Ok((ty, condition)) } - CrossJoin => { + CrossJoin(constraint) => { let ty = self.egraph.add(Node::Inner); - let condition = self.egraph.add(Node::true_()); + let condition = self.bind_join_constraint(constraint)?; Ok((ty, condition)) } LeftSemi(constraint) => { diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 2be2f1f1..ed2481cd 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -35,8 +35,7 @@ pub struct TableRefId { impl std::fmt::Debug for TableRefId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // TODO: print schema id - write!(f, "${}", self.table_id) + write!(f, "${}.{}", self.schema_id, self.table_id) } } @@ -126,8 +125,11 @@ impl ColumnRefId { impl std::fmt::Debug for ColumnRefId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // TODO: print schema id - write!(f, "${}.{}", self.table_id, self.column_id)?; + write!( + f, + "${}.{}.{}", + self.schema_id, self.table_id, self.column_id + )?; if self.table_occurrence != 0 { write!(f, "({})", self.table_occurrence)?; } diff --git a/src/types/mod.rs b/src/types/mod.rs index f0ee6020..838831e3 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -114,7 +114,7 @@ impl From<&crate::parser::DataType> for DataType { Char(_) | Varchar(_) | String(_) | Text => Self::String, Bytea | Binary(_) | Varbinary(_) | Blob(_) => Self::Blob, // Real => Self::Float32, - Float(_) | Double => Self::Float64, + Float(_) | Double(_) => Self::Float64, SmallInt(_) => Self::Int16, Int(_) | Integer(_) => Self::Int32, BigInt(_) => Self::Int64, @@ -129,7 +129,7 @@ impl From<&crate::parser::DataType> for DataType { Date => Self::Date, Timestamp(_, TimezoneInfo::None) => Self::Timestamp, Timestamp(_, TimezoneInfo::Tz) => Self::TimestampTz, - Interval => Self::Interval, + Interval { .. } => Self::Interval, Custom(name, items) => { if name.to_string().to_lowercase() == "vector" { if items.len() != 1 {