Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 164 additions & 2 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import typing as t
from collections import defaultdict
from collections import defaultdict, OrderedDict

from sqlglot import exp, parse_one
from sqlglot.transforms import remove_precision_parameterized_types
Expand Down Expand Up @@ -891,6 +891,60 @@ def _build_partitioned_by_exp(

return exp.PartitionedByProperty(this=this)

def _create_table(
    self,
    table_name_or_schema: t.Union[exp.Schema, TableName],
    expression: t.Optional[exp.Expression],
    exists: bool = True,
    replace: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    table_kind: t.Optional[str] = None,
    track_rows_processed: bool = True,
    **kwargs: t.Any,
) -> None:
    """Create a table, adding a ``WITH CONNECTION`` clause when one is configured.

    The ``table_properties``, ``table_format`` and ``storage_format`` entries of
    ``kwargs`` are run through ``_prepare_create_table_properties``; the normalized
    properties replace ``kwargs["table_properties"]`` before anything else happens.

    When no connection property was extracted, creation is delegated unchanged to
    the parent implementation. Otherwise the CREATE statement is built here,
    rendered to SQL, patched with the connection clause and executed directly.

    Args:
        table_name_or_schema: Target table name, or a schema expression describing it.
        expression: Optional expression forwarded to the CREATE builder.
        exists: Forwarded to the CREATE builder / parent implementation.
        replace: Forwarded to the CREATE builder / parent implementation.
        target_columns_to_types: Optional mapping of column names to data types.
        table_description: Table comment; on the connection path it is only passed
            on when comments are enabled and supported in the schema definition.
        column_descriptions: Column comments; only forwarded on the parent path.
        table_kind: Optional table kind forwarded to the CREATE builder.
        track_rows_processed: Forwarded to ``execute`` / the parent implementation.
        **kwargs: Additional creation options (see above for the keys read here).
    """
    properties, connection = self._prepare_create_table_properties(
        kwargs.get("table_properties"),
        kwargs.get("table_format"),
        kwargs.get("storage_format"),
    )
    kwargs["table_properties"] = properties

    if connection is not None:
        # The connection clause has no expression-level representation here, so
        # the statement is rendered first and the clause is spliced into the SQL.
        include_comment = (
            self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled
        )
        create_exp = self._build_create_table_exp(
            table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description if include_comment else None,
            table_kind=table_kind,
            **kwargs,
        )
        statement = self._inject_connection_clause(
            self._to_sql(create_exp),
            self._connection_clause_sql(connection),
        )
        self.execute(statement, track_rows_processed=track_rows_processed)
        return

    super()._create_table(
        table_name_or_schema,
        expression,
        exists=exists,
        replace=replace,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        table_kind=table_kind,
        track_rows_processed=track_rows_processed,
        **kwargs,
    )

def _build_table_properties_exp(
self,
catalog_name: t.Optional[str] = None,
Expand Down Expand Up @@ -926,12 +980,120 @@ def _build_table_properties_exp(
),
)

properties.extend(self._table_or_view_properties_to_expressions(table_properties))
if table_properties:
for key, value in table_properties.items():
properties.append(exp.Property(this=key, value=value.copy()))

if properties:
return exp.Properties(expressions=properties)
return None

def _prepare_create_table_properties(
    self,
    table_properties: t.Optional[t.Dict[str, exp.Expression]],
    table_format: t.Optional[str],
    storage_format: t.Optional[str],
) -> t.Tuple[OrderedDict[str, exp.Expression], t.Optional[exp.Expression]]:
    """Normalize table properties and pull out the connection entry.

    Entries whose value is ``None`` are dropped. A ``connection`` /
    ``with_connection`` key (case-insensitive) is removed from the properties and
    returned separately. Remaining keys are de-duplicated case-insensitively: the
    last spelling wins and takes the position of the most recent insertion.

    For Iceberg tables (``table_format`` argument, or a string-literal
    ``table_format`` property when the argument is empty) the ``table_format``
    and ``file_format`` properties are set to upper-cased string literals, and
    both ``storage_uri`` and a connection are required.

    Args:
        table_properties: Raw properties keyed by name, or ``None``.
        table_format: Explicit table format; takes precedence over the property.
        storage_format: Fallback for ``file_format``; ``PARQUET`` when empty.

    Returns:
        A tuple of the normalized properties and the connection expression
        (``None`` when no connection entry was present).

    Raises:
        SQLMeshError: If an Iceberg table lacks ``storage_uri`` or a connection.
    """
    props: OrderedDict[str, exp.Expression] = OrderedDict()
    connection: t.Optional[exp.Expression] = None

    def find_key(lowered: str) -> t.Optional[str]:
        # Keys are unique case-insensitively, so the first hit is the only hit.
        return next((name for name in props if name.lower() == lowered), None)

    def lookup(lowered: str) -> t.Optional[exp.Expression]:
        match = find_key(lowered)
        return None if match is None else props[match]

    def store(name: str, lowered: str, node: exp.Expression) -> None:
        # Drop any differently-cased spelling first so the new entry goes last.
        stale = find_key(lowered)
        if stale is not None:
            del props[stale]
        props[name] = node

    for raw_key, node in (table_properties or {}).items():
        if node is None:
            continue
        lowered_key = raw_key.lower()
        if lowered_key == "connection" or lowered_key == "with_connection":
            connection = node
        else:
            store(raw_key, lowered_key, node.copy())

    resolved_format = table_format.lower() if table_format else None
    if not resolved_format:
        declared = lookup("table_format")
        if isinstance(declared, exp.Literal) and declared.is_string:
            resolved_format = declared.this.lower()

    if resolved_format != "iceberg":
        return props, connection

    for option, fallback in (
        ("table_format", resolved_format),
        ("file_format", storage_format or "PARQUET"),
    ):
        store(
            option,
            option,
            self._ensure_upper_string_literal(lookup(option), default=fallback),
        )

    if find_key("storage_uri") is None:
        raise SQLMeshError(
            "BigQuery Iceberg tables require `storage_uri` to be set in physical_properties."
        )

    if connection is None:
        raise SQLMeshError(
            "BigQuery Iceberg tables require a `connection` entry in physical_properties."
        )

    return props, connection

def _ensure_upper_string_literal(
    self,
    expression: t.Optional[exp.Expression],
    default: str,
) -> exp.Expression:
    """Return an upper-cased string literal for ``expression`` where possible.

    Args:
        expression: The property value, or ``None`` when the property is unset.
        default: Text to use (upper-cased) when ``expression`` is ``None``.

    Returns:
        A new string literal holding the upper-cased text when ``expression`` is
        ``None`` or a string literal; otherwise a copy of ``expression`` as-is.
    """
    if expression is None:
        text = default
    elif isinstance(expression, exp.Literal) and expression.is_string:
        text = expression.this
    else:
        # Anything that is not a string literal is passed through untouched,
        # but as a copy so the caller's expression is never shared.
        return expression.copy()
    return exp.Literal.string(text.upper())

def _connection_clause_sql(self, connection_expression: exp.Expression) -> str:
    """Render the operand that follows ``WITH CONNECTION``.

    Args:
        connection_expression: The value of the connection property.

    Returns:
        ``DEFAULT`` when the value is a string literal spelling that word in any
        case (surrounding whitespace ignored); a quoted identifier built from the
        stripped text for any other string literal; otherwise the expression
        rendered through ``_to_sql``.
    """
    node = connection_expression.copy()
    if not (isinstance(node, exp.Literal) and node.is_string):
        return self._to_sql(node)

    name = node.this.strip()
    if name.upper() == "DEFAULT":
        return "DEFAULT"
    return exp.to_identifier(name, quoted=True).sql(dialect=self.dialect)

@staticmethod
def _inject_connection_clause(create_sql: str, connection_sql: str) -> str:
    """Insert a ``WITH CONNECTION`` clause into a rendered CREATE TABLE statement.

    The clause is placed directly in front of the table-level ``OPTIONS``
    keyword, i.e. the first standalone ``OPTIONS`` token that sits outside any
    parentheses, string literal or quoted identifier. Column-level
    ``OPTIONS (...)`` inside the column list, identifiers that merely contain
    the text (``MY_OPTIONS``) and string contents are therefore never mistaken
    for the insertion point. When no such token exists the clause is appended
    to the end of the statement.

    Args:
        create_sql: The rendered CREATE TABLE statement.
        connection_sql: The already-rendered connection operand.

    Returns:
        The statement with ``WITH CONNECTION <connection_sql>`` added.
    """
    keyword = "OPTIONS"
    clause = f"WITH CONNECTION {connection_sql}"

    def is_word_char(char: str) -> bool:
        return char.isalnum() or char == "_"

    insert_at = -1
    depth = 0
    active_quote: t.Optional[str] = None
    index = 0
    length = len(create_sql)
    while index < length:
        char = create_sql[index]
        if active_quote is not None:
            if char == "\\":
                # Skip the escaped character so an escaped quote does not
                # terminate the literal / quoted identifier.
                index += 2
                continue
            if char == active_quote:
                active_quote = None
        elif char in "'\"`":
            active_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        elif depth == 0 and create_sql.startswith(keyword, index):
            end = index + len(keyword)
            preceded_by_word = index > 0 and is_word_char(create_sql[index - 1])
            followed_by_word = end < length and is_word_char(create_sql[end])
            if not preceded_by_word and not followed_by_word:
                insert_at = index
                break
        index += 1

    if insert_at < 0:
        separator = "" if create_sql.endswith(" ") else " "
        return f"{create_sql}{separator}{clause}"

    prefix, suffix = create_sql[:insert_at], create_sql[insert_at:]
    if not prefix.endswith(" "):
        prefix = f"{prefix} "
    return f"{prefix}{clause} {suffix}"

def _build_column_def(
self,
col_name: str,
Expand Down