16 changes: 13 additions & 3 deletions integration/schema_sync/dev.sh
@@ -42,10 +42,20 @@ pg_dump \
--no-publications pgdog2 > destination.sql

for f in source.sql destination.sql; do
sed -i '/^\\restrict.*$/d' $f
sed -i '/^\\unrestrict.*$/d' $f
sed -i.bak '/^\\restrict.*$/d' $f
sed -i.bak '/^\\unrestrict.*$/d' $f
done

rm -f source.sql.bak destination.sql.bak

# Verify integer primary keys are rewritten to bigint, and no other differences exist
DIFF_OUTPUT=$(diff source.sql destination.sql || true)
echo "$DIFF_OUTPUT" | grep -q 'flag_id integer NOT NULL' || { echo "Expected flag_id integer->bigint rewrite"; exit 1; }
echo "$DIFF_OUTPUT" | grep -q 'flag_id bigint NOT NULL' || { echo "Expected flag_id integer->bigint rewrite"; exit 1; }
echo "$DIFF_OUTPUT" | grep -q 'setting_id integer NOT NULL' || { echo "Expected setting_id integer->bigint rewrite"; exit 1; }
echo "$DIFF_OUTPUT" | grep -q 'setting_id bigint NOT NULL' || { echo "Expected setting_id integer->bigint rewrite"; exit 1; }
sed -i.bak 's/flag_id integer NOT NULL/flag_id bigint NOT NULL/g' source.sql
sed -i.bak 's/setting_id integer NOT NULL/setting_id bigint NOT NULL/g' source.sql
rm -f source.sql.bak
diff source.sql destination.sql
rm source.sql
rm destination.sql
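
Note on the sed change above: plain sed -i is a GNU extension, while BSD/macOS sed requires an explicit suffix argument after -i. Passing -i.bak is accepted by both implementations, at the cost of .bak backup files, which the script now deletes once the loop finishes.

The verification then confirms that the only schema difference is the integer-to-bigint rewrite: it greps the diff for both the integer (source) and bigint (destination) forms of each primary key, applies the same rewrite to source.sql, and requires the final diff to be empty.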
13 changes: 13 additions & 0 deletions integration/schema_sync/ecommerce_schema.sql
@@ -992,3 +992,16 @@ COMMENT ON TABLE core.users IS 'User accounts with role-based access control';
COMMENT ON TABLE inventory.products IS 'Product catalog with full-text search capabilities';
COMMENT ON TABLE sales.orders IS 'Customer orders partitioned by creation date';
COMMENT ON TABLE inventory.stock_levels IS 'Inventory levels partitioned by warehouse';

-- Simple tables with integer primary keys
CREATE TABLE core.settings (
setting_id SERIAL PRIMARY KEY,
setting_key VARCHAR(100) NOT NULL UNIQUE,
setting_value TEXT
);

CREATE TABLE core.feature_flags (
flag_id INTEGER PRIMARY KEY,
flag_name VARCHAR(100) NOT NULL UNIQUE,
is_enabled BOOLEAN NOT NULL DEFAULT FALSE
);
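
These two tables exercise the integer-to-bigint primary key rewrite that dev.sh above verifies: a pg_dump of the source schema lists setting_id and flag_id as integer NOT NULL, while a dump of the synced destination lists the same columns as bigint NOT NULL.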
190 changes: 187 additions & 3 deletions pgdog/src/backend/schema/sync/pg_dump.rs
@@ -1,11 +1,18 @@
//! Wrapper around pg_dump.

use std::{ops::Deref, str::from_utf8};
use std::{
collections::{HashMap, HashSet},
ops::Deref,
str::from_utf8,
};

use lazy_static::lazy_static;
use pg_query::{
protobuf::{AlterTableType, ConstrType, ObjectType, ParseResult},
NodeEnum,
protobuf::{
AlterTableCmd, AlterTableStmt, AlterTableType, ColumnDef, ConstrType, ObjectType,
ParseResult, RangeVar, String as PgString, TypeName,
},
Node, NodeEnum,
};
use pgdog_config::QueryParserEngine;
use regex::Regex;
@@ -18,6 +25,14 @@ use crate::{
frontend::router::parser::{sequence::Sequence, Column, Table},
};

/// Key for looking up column types during pg_dump parsing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct ColumnTypeKey<'a> {
schema: &'a str,
table: &'a str,
column: &'a str,
}

fn deparse_node(node: NodeEnum) -> Result<String, pg_query::Error> {
match config().config.general.query_parser_engine {
QueryParserEngine::PgQueryProtobuf => node.deparse(),
@@ -220,6 +235,8 @@ impl PgDumpOutput {
/// e.g., CREATE TABLE, primary key.
pub fn statements(&self, state: SyncState) -> Result<Vec<Statement<'_>>, Error> {
let mut result = vec![];
let mut integer_primary_keys = HashSet::<Column<'_>>::new();
let mut column_types: HashMap<ColumnTypeKey<'_>, &str> = HashMap::new();

for stmt in &self.stmts.stmts {
let (_, original_start) = self
@@ -239,6 +256,39 @@
stmt.if_not_exists = true;
deparse_node(NodeEnum::CreateStmt(stmt))?
};

// Track column types for later PRIMARY KEY detection
if let Some(ref relation) = stmt.relation {
let schema = if relation.schemaname.is_empty() {
"public"
} else {
relation.schemaname.as_str()
};
let table_name = relation.relname.as_str();

for elt in &stmt.table_elts {
if let Some(NodeEnum::ColumnDef(col_def)) = &elt.node {
if let Some(ref type_name) = col_def.type_name {
// Get the last element of the type name (e.g., "int4" from ["pg_catalog", "int4"])
if let Some(last_name) = type_name.names.last() {
if let Some(NodeEnum::String(PgString { sval })) =
&last_name.node
{
column_types.insert(
ColumnTypeKey {
schema,
table: table_name,
column: col_def.colname.as_str(),
},
sval.as_str(),
);
}
}
}
}
}
}

if state == SyncState::PreData {
// CREATE TABLE is always good.
let table =
@@ -291,6 +341,61 @@
| ConstrType::ConstrNotnull
| ConstrType::ConstrNull
) {
// Track INTEGER primary keys
if cons.contype()
== ConstrType::ConstrPrimary
{
if let Some(ref relation) =
stmt.relation
{
let schema = if relation
.schemaname
.is_empty()
{
"public"
} else {
relation.schemaname.as_str()
};
let table_name =
relation.relname.as_str();

for key in &cons.keys {
if let Some(NodeEnum::String(
PgString { sval },
)) = &key.node
{
let col_name =
sval.as_str();
let key = ColumnTypeKey {
schema,
table: table_name,
column: col_name,
};
if let Some(&type_name) =
column_types.get(&key)
{
// Check for integer-family types: int4, int2, serial, smallserial, integer, smallint
if matches!(
type_name,
"int4"
| "int2"
| "serial"
| "smallserial"
| "integer"
| "smallint"
) {
integer_primary_keys.insert(Column {
name: col_name,
table: Some(table_name),
schema: Some(schema),
});
}
}
}
}
}
}

if state == SyncState::PreData {
result.push(Statement::Other {
sql: original.to_string(),
@@ -545,6 +650,57 @@ impl PgDumpOutput {
}
}

// Convert INTEGER primary keys to BIGINT
if state == SyncState::PreData {
for column in &integer_primary_keys {
let alter_stmt = AlterTableStmt {
relation: Some(RangeVar {
schemaname: column.schema.unwrap_or("public").to_owned(),
relname: column.table.unwrap_or_default().to_owned(),
inh: true,
relpersistence: "p".to_owned(),
..Default::default()
}),
cmds: vec![Node {
node: Some(NodeEnum::AlterTableCmd(Box::new(AlterTableCmd {
subtype: AlterTableType::AtAlterColumnType.into(),
name: column.name.to_owned(),
def: Some(Box::new(Node {
node: Some(NodeEnum::ColumnDef(Box::new(ColumnDef {
type_name: Some(TypeName {
names: vec![
Node {
node: Some(NodeEnum::String(PgString {
sval: "pg_catalog".to_owned(),
})),
},
Node {
node: Some(NodeEnum::String(PgString {
sval: "int8".to_owned(),
})),
},
],
typemod: -1,
..Default::default()
}),
..Default::default()
}))),
})),
behavior: pg_query::protobuf::DropBehavior::DropRestrict.into(),
..Default::default()
}))),
}],
objtype: ObjectType::ObjectTable.into(),
..Default::default()
};
let sql = deparse_node(NodeEnum::AlterTableStmt(alter_stmt))?;
result.push(Statement::Other {
sql,
idempotent: true,
});
}
}

Ok(result)
}
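
Deparsing the hand-built AlterTableStmt above yields a plain ALTER TABLE <schema>.<table> ALTER COLUMN <column> TYPE bigint, pushed as an idempotent statement; the test at the bottom of this file asserts the exact output for a single integer primary key.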

@@ -708,4 +864,32 @@ ALTER TABLE ONLY public.users
let statements = output.statements(SyncState::PostData).unwrap();
assert!(statements.is_empty());
}

#[test]
fn test_bigint_rewrite() {
let query = r#"
CREATE TABLE test (id INTEGER, value TEXT);
ALTER TABLE test ADD CONSTRAINT id_pkey PRIMARY KEY (id);"#;

let output = PgDumpOutput {
stmts: parse(query).unwrap().protobuf,
original: query.to_owned(),
};

let statements = output.statements(SyncState::PreData).unwrap();
assert_eq!(statements.len(), 3);

assert_eq!(
statements[0].deref(),
"CREATE TABLE IF NOT EXISTS test (id int, value text)"
);
assert_eq!(
statements[1].deref(),
"\nALTER TABLE test ADD CONSTRAINT id_pkey PRIMARY KEY (id)"
);
assert_eq!(
statements[2].deref(),
"ALTER TABLE public.test ALTER COLUMN id TYPE bigint"
);
}
}
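
The integer-family match above also accepts serial and smallserial, which the raw parse tree keeps as bare type names (their expansion into int4 plus a sequence happens later, during parse analysis). A hypothetical companion test for that path, not part of this diff and sketched on the assumption that the rewrite statements are appended after all parsed statements, could follow the same pattern as test_bigint_rewrite:

#[test]
fn test_serial_rewrite() {
    // SERIAL parses as the bare type name "serial", so the integer
    // primary key detection above should catch it as well.
    let query = r#"
CREATE TABLE test (id SERIAL, value TEXT);
ALTER TABLE test ADD CONSTRAINT id_pkey PRIMARY KEY (id);"#;

    let output = PgDumpOutput {
        stmts: parse(query).unwrap().protobuf,
        original: query.to_owned(),
    };

    let statements = output.statements(SyncState::PreData).unwrap();

    // The bigint rewrite is emitted after the parsed statements.
    assert_eq!(
        statements.last().unwrap().deref(),
        "ALTER TABLE public.test ALTER COLUMN id TYPE bigint"
    );
}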
2 changes: 1 addition & 1 deletion pgdog/src/frontend/router/parser/column.rs
@@ -10,7 +10,7 @@ use super::{Error, Table};
use crate::util::escape_identifier;

/// Column name extracted from a query.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Column<'a> {
/// Column name.
pub name: &'a str,
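
Eq and Hash are added so that pg_dump.rs above can collect columns into a HashSet<Column>; the derives work because every field is a reference or an Option of one. A minimal sketch of the new capability (a hypothetical test, assuming fields other than name keep their derived Default values):

#[cfg(test)]
mod hash_sketch {
    use super::Column;
    use std::collections::HashSet;

    #[test]
    fn column_keys_a_hash_set() {
        // Column can now key a HashSet, as integer_primary_keys does
        // in pg_dump.rs above.
        let mut seen: HashSet<Column> = HashSet::new();
        seen.insert(Column { name: "id", ..Default::default() });
        assert!(seen.contains(&Column { name: "id", ..Default::default() }));
    }
}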