diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py
index 26abad9ebc..b2f227ec59 100644
--- a/sqlmesh/core/engine_adapter/bigquery.py
+++ b/sqlmesh/core/engine_adapter/bigquery.py
@@ -2,7 +2,7 @@
 
 import logging
 import typing as t
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 
 from sqlglot import exp, parse_one
 from sqlglot.transforms import remove_precision_parameterized_types
@@ -891,6 +891,67 @@ def _build_partitioned_by_exp(
 
         return exp.PartitionedByProperty(this=this)
 
+    def _create_table(
+        self,
+        table_name_or_schema: t.Union[exp.Schema, TableName],
+        expression: t.Optional[exp.Expression],
+        exists: bool = True,
+        replace: bool = False,
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        table_kind: t.Optional[str] = None,
+        track_rows_processed: bool = True,
+        **kwargs: t.Any,
+    ) -> None:
+        """Create a table, emitting a ``WITH CONNECTION`` clause when one is configured.
+
+        Table properties are first normalized by ``_prepare_create_table_properties``.
+        If no connection property is present, creation is delegated to the parent
+        implementation. Otherwise the CREATE statement is rendered to SQL and the
+        connection clause is spliced into the text before it is executed.
+        """
+        normalized_properties, connection_property = self._prepare_create_table_properties(
+            kwargs.get("table_properties"),
+            kwargs.get("table_format"),
+            kwargs.get("storage_format"),
+        )
+        kwargs["table_properties"] = normalized_properties
+
+        if connection_property is None:
+            super()._create_table(
+                table_name_or_schema,
+                expression,
+                exists=exists,
+                replace=replace,
+                target_columns_to_types=target_columns_to_types,
+                table_description=table_description,
+                column_descriptions=column_descriptions,
+                table_kind=table_kind,
+                track_rows_processed=track_rows_processed,
+                **kwargs,
+            )
+            return
+
+        create_expression = self._build_create_table_exp(
+            table_name_or_schema,
+            expression=expression,
+            exists=exists,
+            replace=replace,
+            target_columns_to_types=target_columns_to_types,
+            table_description=(
+                table_description
+                if self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled
+                else None
+            ),
+            table_kind=table_kind,
+            **kwargs,
+        )
+        sql = self._to_sql(create_expression)
+        connection_sql = self._connection_clause_sql(connection_property)
+        sql = self._inject_connection_clause(sql, connection_sql)
+        self.execute(sql, track_rows_processed=track_rows_processed)
+
     def _build_table_properties_exp(
         self,
         catalog_name: t.Optional[str] = None,
@@ -926,12 +987,191 @@ def _build_table_properties_exp(
                 ),
             )
 
-        properties.extend(self._table_or_view_properties_to_expressions(table_properties))
+        if table_properties:
+            for key, value in table_properties.items():
+                properties.append(exp.Property(this=key, value=value.copy()))
 
         if properties:
             return exp.Properties(expressions=properties)
         return None
 
+    def _prepare_create_table_properties(
+        self,
+        table_properties: t.Optional[t.Dict[str, exp.Expression]],
+        table_format: t.Optional[str],
+        storage_format: t.Optional[str],
+    ) -> t.Tuple[OrderedDict[str, exp.Expression], t.Optional[exp.Expression]]:
+        """Normalize table properties and extract the connection setting.
+
+        Entries whose value is ``None`` are dropped, and a ``connection`` /
+        ``with_connection`` entry (matched case-insensitively) is removed from the
+        properties and returned separately. Keys that differ only by case are
+        collapsed, keeping the last spelling and value seen.
+
+        When the table format is ``iceberg`` (taken from ``table_format`` or, failing
+        that, from a string ``table_format`` property), the ``table_format`` and
+        ``file_format`` properties are filled in and upper-cased.
+
+        Returns:
+            A ``(properties, connection)`` tuple; ``connection`` is ``None`` when no
+            connection entry was supplied.
+
+        Raises:
+            SQLMeshError: If an Iceberg table has no ``storage_uri`` property or no
+                connection entry.
+        """
+        normalized_properties: OrderedDict[str, exp.Expression] = OrderedDict()
+        connection_property: t.Optional[exp.Expression] = None
+
+        if table_properties:
+            for key, value in table_properties.items():
+                if value is None:
+                    continue
+                key_lower = key.lower()
+                if key_lower in {"connection", "with_connection"}:
+                    connection_property = value
+                    continue
+                # Drop any earlier spelling of this key so the latest casing and value win
+                for existing_key in list(normalized_properties.keys()):
+                    if existing_key.lower() == key_lower:
+                        normalized_properties.pop(existing_key)
+                        break
+                normalized_properties[key] = value.copy()
+
+        def _get_property(name: str) -> t.Optional[exp.Expression]:
+            for existing_key, value in normalized_properties.items():
+                if existing_key.lower() == name:
+                    return value
+            return None
+
+        def _set_property(name: str, expression: exp.Expression) -> None:
+            for existing_key in list(normalized_properties.keys()):
+                if existing_key.lower() == name:
+                    normalized_properties.pop(existing_key)
+                    break
+            normalized_properties[name] = expression
+
+        def _has_property(name: str) -> bool:
+            return any(existing_key.lower() == name for existing_key in normalized_properties)
+
+        normalized_table_format = table_format.lower() if table_format else None
+        if not normalized_table_format:
+            existing_table_format = _get_property("table_format")
+            if isinstance(existing_table_format, exp.Literal) and existing_table_format.is_string:
+                normalized_table_format = existing_table_format.this.lower()
+        is_iceberg = normalized_table_format == "iceberg"
+
+        if is_iceberg:
+            table_format_expression = self._ensure_upper_string_literal(
+                _get_property("table_format"),
+                default=normalized_table_format or "iceberg",
+            )
+            _set_property("table_format", table_format_expression)
+
+            file_format_expression = self._ensure_upper_string_literal(
+                _get_property("file_format"),
+                default=storage_format or "PARQUET",
+            )
+            _set_property("file_format", file_format_expression)
+
+            if not _has_property("storage_uri"):
+                raise SQLMeshError(
+                    "BigQuery Iceberg tables require `storage_uri` to be set in physical_properties."
+                )
+
+            if connection_property is None:
+                raise SQLMeshError(
+                    "BigQuery Iceberg tables require a `connection` entry in physical_properties."
+                )
+
+        return normalized_properties, connection_property
+
+    def _ensure_upper_string_literal(
+        self,
+        expression: t.Optional[exp.Expression],
+        default: str,
+    ) -> exp.Expression:
+        """Return a copy of ``expression`` with string literals upper-cased.
+
+        If ``expression`` is ``None`` a string literal of ``default.upper()`` is
+        returned; non-string expressions are copied unchanged.
+        """
+        if expression is None:
+            return exp.Literal.string(default.upper())
+
+        expression = expression.copy()
+        if isinstance(expression, exp.Literal) and expression.is_string:
+            return exp.Literal.string(expression.this.upper())
+        return expression
+
+    def _connection_clause_sql(self, connection_expression: exp.Expression) -> str:
+        """Render the connection property as the operand of ``WITH CONNECTION``.
+
+        String literals are stripped; ``DEFAULT`` (any case) is emitted as a bare
+        keyword and any other string becomes a single quoted identifier. Non-string
+        expressions are rendered with ``_to_sql``.
+        """
+        expression = connection_expression.copy()
+        if isinstance(expression, exp.Literal) and expression.is_string:
+            value = expression.this.strip()
+            if value.upper() == "DEFAULT":
+                return "DEFAULT"
+            return exp.to_identifier(value, quoted=True).sql(dialect=self.dialect)
+
+        return self._to_sql(expression)
+
+    @staticmethod
+    def _inject_connection_clause(create_sql: str, connection_sql: str) -> str:
+        """Insert ``WITH CONNECTION <connection_sql>`` into a rendered CREATE TABLE.
+
+        The clause is placed in front of the first table-level ``OPTIONS`` or ``AS``
+        keyword, i.e. one that is outside parentheses and quotes, so column-level
+        ``OPTIONS(...)`` and quoted names or strings containing those words are not
+        mistaken for the insertion point. Without such a keyword the clause is
+        appended to the statement. Keywords are matched in upper case only.
+        """
+        clause = f"WITH CONNECTION {connection_sql}"
+        length = len(create_sql)
+        insert_at: t.Optional[int] = None
+        depth = 0
+        quote: t.Optional[str] = None
+        pos = 0
+        while pos < length:
+            char = create_sql[pos]
+            if quote is not None:
+                if char == "\\":
+                    # Skip the escaped character so it cannot close the quote.
+                    pos += 2
+                    continue
+                if char == quote:
+                    quote = None
+            elif char in "'\"`":
+                quote = char
+            elif char == "(":
+                depth += 1
+            elif char == ")":
+                depth -= 1
+            elif char.isalnum() or char == "_":
+                # Consume the whole word so keywords never match inside identifiers.
+                end = pos
+                while end < length and (create_sql[end].isalnum() or create_sql[end] == "_"):
+                    end += 1
+                if depth == 0 and create_sql[pos:end] in ("OPTIONS", "AS"):
+                    insert_at = pos
+                    break
+                pos = end
+                continue
+            pos += 1
+
+        if insert_at is None:
+            separator = "" if create_sql.endswith(" ") else " "
+            return f"{create_sql}{separator}{clause}"
+
+        prefix = create_sql[:insert_at]
+        if not prefix.endswith(" "):
+            prefix = f"{prefix} "
+        return f"{prefix}{clause} {create_sql[insert_at:]}"
+
     def _build_column_def(
         self,
         col_name: str,