Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlmesh/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,5 @@
HYBRID = "hybrid"

DISABLE_SQLMESH_STATE_MIGRATION = "SQLMESH__AIRFLOW__DISABLE_STATE_MIGRATION"

LIQUID_CLUSTERING_KEYWORDS: frozenset = frozenset({"AUTO", "NONE"})
28 changes: 25 additions & 3 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Token

from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS, MAX_MODEL_DEFINITION_SIZE
from sqlmesh.utils import get_source_columns_to_types
from sqlmesh.utils.errors import SQLMeshError, ConfigError
from sqlmesh.utils.pandas import columns_to_types_from_df
Expand Down Expand Up @@ -663,6 +663,27 @@ def parse(self: Parser) -> t.Optional[exp.Expr]:
value = exp.tuple_(*partitioned_by.this.expressions)
else:
value = partitioned_by.this
elif key == "clustered_by":
# Bare AUTO / NONE are Databricks liquid clustering keywords, not column refs.
# Detect keywords by token type: unquoted bare identifiers arrive as VAR tokens.
# Backtick-quoted identifiers (e.g. `auto`) have IDENTIFIER token type and are
# treated as real column names.
if (
self._curr is not None
and self._curr.token_type == TokenType.VAR
and self._curr.text.upper() in LIQUID_CLUSTERING_KEYWORDS
):
value = exp.Var(this=self._curr.text.upper())
self._advance()
else:
parsed = self._parse_bracket(self._parse_field(any_token=True))
# Unwrap Paren wrapping a bare column to match partitioned_by normalisation:
# clustered_by (a) → stored as Column(a), not Paren(Column(a)).
# Preserve parens around function expressions: (TO_DATE(col)) stays as-is.
if isinstance(parsed, exp.Paren) and isinstance(parsed.this, exp.Column):
value = parsed.unnest()
else:
value = parsed
else:
value = self._parse_bracket(self._parse_field(any_token=True))

Expand Down Expand Up @@ -1096,8 +1117,9 @@ def extend_sqlglot() -> None:
DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}",
Jinja: lambda self, e: e.name,
JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};",
JinjaStatement: lambda self,
e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};",
JinjaStatement: lambda self, e: (
f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};"
),
VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e),
MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})",
MacroFunc: _macro_func_sql,
Expand Down
15 changes: 11 additions & 4 deletions sqlmesh/core/engine_adapter/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlglot import exp

from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS
from sqlmesh.core.dialect import to_schema
from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin
from sqlmesh.core.engine_adapter.shared import (
Expand Down Expand Up @@ -386,10 +387,16 @@ def _build_table_properties_exp(
table_kind=table_kind,
)
if clustered_by:
# Databricks expects wrapped CLUSTER BY expressions
clustered_by_exp = exp.Cluster(
expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])]
)
if len(clustered_by) == 1 and isinstance(clustered_by[0], exp.Var):
if clustered_by[0].name.upper() not in LIQUID_CLUSTERING_KEYWORDS:
raise ValueError(f"Unexpected bare Var in clustered_by: {clustered_by[0]!r}")
# exp.Cluster with a bare Var generates: CLUSTER BY AUTO (no parens)
clustered_by_exp = exp.Cluster(expressions=[clustered_by[0].copy()])
else:
# Databricks expects column expressions wrapped in a tuple
clustered_by_exp = exp.Cluster(
expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])]
)
expressions = properties.expressions if properties else []
expressions.append(clustered_by_exp)
properties = exp.Properties(expressions=expressions)
Expand Down
6 changes: 6 additions & 0 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,12 @@ def validate_definition(self) -> None:
values = [
col.name
for expr in values
if not (
field == "clustered_by"
and (self.dialect or "").lower() == "databricks"
and isinstance(expr, exp.Var)
and expr.name.upper() in c.LIQUID_CLUSTERING_KEYWORDS
)
for col in t.cast(
exp.Expr, exp.maybe_parse(expr, dialect=self.dialect)
).find_all(exp.Column)
Expand Down
30 changes: 29 additions & 1 deletion sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from sqlmesh.core import dialect as d
from sqlmesh.core.config.common import VirtualEnvironmentMode
from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS
from sqlmesh.core.config.linter import LinterConfig
from sqlmesh.core.dialect import normalize_model_name
from sqlmesh.utils import classproperty
Expand Down Expand Up @@ -190,10 +191,13 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]:

@field_validator("partitioned_by_", "clustered_by", mode="before")
def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
field = info.field_name or ""
dialect = (get_dialect(info) or "").lower()

if (
isinstance(v, list)
and all(isinstance(i, str) for i in v)
and (info.field_name or "") == "partitioned_by_"
and field == "partitioned_by_"
):
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
# however, we should only invoke this if the list contains strings because this validator is also
Expand All @@ -206,9 +210,33 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L
)
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v

if isinstance(v, str) and field == "clustered_by":
v = [v]

if isinstance(v, list) and field == "clustered_by" and dialect == "databricks":
# When deserializing from JSON, clustered_by is stored as List[str].
# Restore keyword sentinels (AUTO/NONE) before list_of_fields_validator normalises
# them into quoted columns.
v = [
exp.Var(this=item.upper())
if isinstance(item, str) and item.upper() in LIQUID_CLUSTERING_KEYWORDS
else item
for item in v
]

expressions = list_of_fields_validator(v, validation_data(info))

for expression in expressions:
# AUTO and NONE are Databricks liquid clustering keywords, not column references.
# Only skip for clustered_by with the Databricks dialect — meaningless elsewhere.
if (
field == "clustered_by"
and dialect == "databricks"
and isinstance(expression, exp.Var)
and expression.name.upper() in LIQUID_CLUSTERING_KEYWORDS
):
continue

num_cols = len(list(expression.find_all(exp.Column)))

error_msg: t.Optional[str] = None
Expand Down
25 changes: 25 additions & 0 deletions tests/core/engine_adapter/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,31 @@ def test_create_table_clustered_by(mocker: MockFixture, make_mocked_engine_adapt
]


@pytest.mark.parametrize("keyword", ["AUTO", "NONE"])
def test_create_table_clustered_by_keyword(
keyword: str, mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
mocker.patch(
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
)
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

columns_to_types = {
"cola": exp.DataType.build("INT"),
"colb": exp.DataType.build("TEXT"),
}
adapter.create_table(
"test_table",
columns_to_types,
clustered_by=[exp.Var(this=keyword)],
)

sql_calls = to_sql_calls(adapter)
assert sql_calls == [
f"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING) CLUSTER BY {keyword}",
]


def test_get_data_objects_distinguishes_view_types(mocker):
adapter = DatabricksEngineAdapter(lambda: None, default_catalog="test_catalog")

Expand Down
143 changes: 143 additions & 0 deletions tests/core/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,149 @@ def test_sqlglot_extended_correctly(dialect: str) -> None:
assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)"


def test_format_model_expressions_clustered_by():
# Unquoted AUTO / NONE → formatted without backticks or parens
for keyword in ("AUTO", "NONE"):
assert format_model_expressions(
parse(
f"""
MODEL (
name db.test,
kind FULL,
dialect databricks,
clustered_by {keyword}
);
SELECT 1 AS a
"""
)
) == (
f"MODEL (\n"
f" name db.test,\n"
f" kind FULL,\n"
f" dialect databricks,\n"
f" clustered_by {keyword}\n"
f");\n\nSELECT\n 1 AS a"
)

# Backtick-quoted `auto` / `none` → treated as a column, rendered quoted
for name in ("auto", "none"):
assert format_model_expressions(
parse(
f"""
MODEL (
name db.test,
kind FULL,
dialect databricks,
clustered_by `{name}`
);
SELECT 1 AS `{name}`
"""
)
) == (
f"MODEL (\n"
f" name db.test,\n"
f" kind FULL,\n"
f" dialect databricks,\n"
f' clustered_by "{name}"\n'
f');\n\nSELECT\n 1 AS "{name}"'
)

# Parens-wrapped (auto) → treated as a column, parens stripped for single column
# (same normalisation as partitioned_by (a) → a); quoting happens at model-load time
assert format_model_expressions(
parse(
"""
MODEL (
name db.test,
kind FULL,
dialect databricks,
clustered_by (auto)
);
SELECT 1 AS auto
"""
)
) == (
"MODEL (\n"
" name db.test,\n"
" kind FULL,\n"
" dialect databricks,\n"
" clustered_by auto\n"
");\n\nSELECT\n 1 AS auto"
)

# Multi-column → parens preserved, identifiers as-written
# (quoting happens when the model is loaded, not at format time)
assert format_model_expressions(
parse(
"""
MODEL (
name db.test,
kind FULL,
dialect databricks,
clustered_by (a, b)
);
SELECT 1 AS a, 2 AS b
"""
)
) == (
"MODEL (\n"
" name db.test,\n"
" kind FULL,\n"
" dialect databricks,\n"
" clustered_by (a, b)\n"
");\n\nSELECT\n 1 AS a,\n 2 AS b"
)


@pytest.mark.parametrize("keyword", ["AUTO", "NONE"])
def test_format_model_expressions_clustered_by_non_databricks(keyword: str):
"""AUTO/NONE without dialect or with a non-Databricks dialect is parsed as a bare identifier."""
# Without dialect — AUTO/NONE parsed as a plain column name (no special keyword handling)
assert format_model_expressions(
parse(
f"""
MODEL (
name db.test,
kind FULL,
clustered_by {keyword}
);
SELECT 1 AS {keyword.lower()}
"""
)
) == (
f"MODEL (\n"
f" name db.test,\n"
f" kind FULL,\n"
f" clustered_by {keyword}\n"
f");\n\nSELECT\n 1 AS {keyword.lower()}"
)


@pytest.mark.parametrize("keyword", ["AUTO", "NONE"])
def test_format_model_expressions_clustered_by_mixed_list(keyword: str):
"""AUTO/NONE inside a parenthesised list is treated as a regular column name."""
assert format_model_expressions(
parse(
f"""
MODEL (
name db.test,
kind FULL,
dialect databricks,
clustered_by (a, {keyword})
);
SELECT 1 AS a, 2 AS {keyword.lower()}
"""
)
) == (
f"MODEL (\n"
f" name db.test,\n"
f" kind FULL,\n"
f" dialect databricks,\n"
f" clustered_by (a, {keyword})\n"
f");\n\nSELECT\n 1 AS a,\n 2 AS {keyword.lower()}"
)


def test_connected_identifier():
ast = d.parse_one("""SELECT ("x"at time zone 'utc')::timestamp as x""", "redshift")
assert ast.sql("redshift") == """SELECT CAST(("x" AT TIME ZONE 'utc') AS TIMESTAMP) AS x"""
Expand Down
Loading