diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index 66dadb0b5d..80eb9aa5a0 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -96,3 +96,5 @@ HYBRID = "hybrid" DISABLE_SQLMESH_STATE_MIGRATION = "SQLMESH__AIRFLOW__DISABLE_STATE_MIGRATION" + +LIQUID_CLUSTERING_KEYWORDS: frozenset = frozenset({"AUTO", "NONE"}) diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 94b8c2f2ad..b762882659 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -24,7 +24,7 @@ from sqlglot.schema import MappingSchema from sqlglot.tokens import Token -from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE +from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS, MAX_MODEL_DEFINITION_SIZE from sqlmesh.utils import get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError, ConfigError from sqlmesh.utils.pandas import columns_to_types_from_df @@ -663,6 +663,27 @@ def parse(self: Parser) -> t.Optional[exp.Expr]: value = exp.tuple_(*partitioned_by.this.expressions) else: value = partitioned_by.this + elif key == "clustered_by": + # Bare AUTO / NONE are Databricks liquid clustering keywords, not column refs. + # Detect keywords by token type: unquoted bare identifiers arrive as VAR tokens. + # Backtick-quoted identifiers (e.g. `auto`) have IDENTIFIER token type and are + # treated as real column names. + if ( + self._curr is not None + and self._curr.token_type == TokenType.VAR + and self._curr.text.upper() in LIQUID_CLUSTERING_KEYWORDS + ): + value = exp.Var(this=self._curr.text.upper()) + self._advance() + else: + parsed = self._parse_bracket(self._parse_field(any_token=True)) + # Unwrap Paren wrapping a bare column to match partitioned_by normalisation: + # clustered_by (a) → stored as Column(a), not Paren(Column(a)). + # Preserve parens around function expressions: (TO_DATE(col)) stays as-is. + if isinstance(parsed, exp.Paren) and isinstance(parsed.this, exp.Column): + value = parsed.unnest() + else: + value = parsed else: value = self._parse_bracket(self._parse_field(any_token=True)) @@ -1096,8 +1117,9 @@ def extend_sqlglot() -> None: DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}", Jinja: lambda self, e: e.name, JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};", - JinjaStatement: lambda self, - e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};", + JinjaStatement: lambda self, e: ( + f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};" + ), VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e), MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})", MacroFunc: _macro_func_sql, diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index dbf38f0b94..321902d0d5 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -6,6 +6,7 @@ from sqlglot import exp +from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin from sqlmesh.core.engine_adapter.shared import ( @@ -386,10 +387,16 @@ def _build_table_properties_exp( table_kind=table_kind, ) if clustered_by: - # Databricks expects wrapped CLUSTER BY expressions - clustered_by_exp = exp.Cluster( - expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])] - ) + if len(clustered_by) == 1 and isinstance(clustered_by[0], exp.Var): + if clustered_by[0].name.upper() not in LIQUID_CLUSTERING_KEYWORDS: + raise ValueError(f"Unexpected bare Var in clustered_by: {clustered_by[0]!r}") + # exp.Cluster with a bare Var generates: CLUSTER BY AUTO (no parens) + clustered_by_exp = exp.Cluster(expressions=[clustered_by[0].copy()]) + else: + # Databricks expects column expressions wrapped in a tuple + clustered_by_exp = exp.Cluster( + expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])] + ) expressions = properties.expressions if properties else [] expressions.append(clustered_by_exp) properties = exp.Properties(expressions=expressions) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 3533d1b669..18a8cbbc56 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -992,6 +992,12 @@ def validate_definition(self) -> None: values = [ col.name for expr in values + if not ( + field == "clustered_by" + and (self.dialect or "").lower() == "databricks" + and isinstance(expr, exp.Var) + and expr.name.upper() in c.LIQUID_CLUSTERING_KEYWORDS + ) for col in t.cast( exp.Expr, exp.maybe_parse(expr, dialect=self.dialect) ).find_all(exp.Column) diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index d5a93c459c..9dfe5b5c3d 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -12,6 +12,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.constants import LIQUID_CLUSTERING_KEYWORDS from sqlmesh.core.config.linter import LinterConfig from sqlmesh.core.dialect import normalize_model_name from sqlmesh.utils import classproperty @@ -190,10 +191,13 @@ def _gateway_validator(cls, v: t.Any) -> t.Optional[str]: @field_validator("partitioned_by_", "clustered_by", mode="before") def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.List[exp.Expr]: + field = info.field_name or "" + dialect = (get_dialect(info) or "").lower() + if ( isinstance(v, list) and all(isinstance(i, str) for i in v) - and (info.field_name or "") == "partitioned_by_" + and field == "partitioned_by_" ): # this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str] # however, we should only invoke this if the list contains strings because this validator is also @@ -206,9 +210,33 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L ) v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v + if isinstance(v, str) and field == "clustered_by": + v = [v] + + if isinstance(v, list) and field == "clustered_by" and dialect == "databricks": + # When deserializing from JSON, clustered_by is stored as List[str]. + # Restore keyword sentinels (AUTO/NONE) before list_of_fields_validator normalises + # them into quoted columns. + v = [ + exp.Var(this=item.upper()) + if isinstance(item, str) and item.upper() in LIQUID_CLUSTERING_KEYWORDS + else item + for item in v + ] + expressions = list_of_fields_validator(v, validation_data(info)) for expression in expressions: + # AUTO and NONE are Databricks liquid clustering keywords, not column references. + # Only skip for clustered_by with the Databricks dialect — meaningless elsewhere. + if ( + field == "clustered_by" + and dialect == "databricks" + and isinstance(expression, exp.Var) + and expression.name.upper() in LIQUID_CLUSTERING_KEYWORDS + ): + continue + num_cols = len(list(expression.find_all(exp.Column))) error_msg: t.Optional[str] = None diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index 42cbd287f2..199a2a0fb9 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -458,6 +458,31 @@ def test_create_table_clustered_by(mocker: MockFixture, make_mocked_engine_adapt ] +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_create_table_clustered_by_keyword( + keyword: str, mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + columns_to_types = { + "cola": exp.DataType.build("INT"), + "colb": exp.DataType.build("TEXT"), + } + adapter.create_table( + "test_table", + columns_to_types, + clustered_by=[exp.Var(this=keyword)], + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f"CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING) CLUSTER BY {keyword}", + ] + + def test_get_data_objects_distinguishes_view_types(mocker): adapter = DatabricksEngineAdapter(lambda: None, default_catalog="test_catalog") diff --git a/tests/core/test_dialect.py b/tests/core/test_dialect.py index 8b5adf82c0..e6922465d1 100644 --- a/tests/core/test_dialect.py +++ b/tests/core/test_dialect.py @@ -760,6 +760,149 @@ def test_sqlglot_extended_correctly(dialect: str) -> None: assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)" +def test_format_model_expressions_clustered_by(): + # Unquoted AUTO / NONE → formatted without backticks or parens + for keyword in ("AUTO", "NONE"): + assert format_model_expressions( + parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by {keyword} + ); + SELECT 1 AS a + """ + ) + ) == ( + f"MODEL (\n" + f" name db.test,\n" + f" kind FULL,\n" + f" dialect databricks,\n" + f" clustered_by {keyword}\n" + f");\n\nSELECT\n 1 AS a" + ) + + # Backtick-quoted `auto` / `none` → treated as a column, rendered quoted + for name in ("auto", "none"): + assert format_model_expressions( + parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by `{name}` + ); + SELECT 1 AS `{name}` + """ + ) + ) == ( + f"MODEL (\n" + f" name db.test,\n" + f" kind FULL,\n" + f" dialect databricks,\n" + f' clustered_by "{name}"\n' + f');\n\nSELECT\n 1 AS "{name}"' + ) + + # Parens-wrapped (auto) → treated as a column, parens stripped for single column + # (same normalisation as partitioned_by (a) → a); quoting happens at model-load time + assert format_model_expressions( + parse( + """ + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (auto) + ); + SELECT 1 AS auto + """ + ) + ) == ( + "MODEL (\n" + " name db.test,\n" + " kind FULL,\n" + " dialect databricks,\n" + " clustered_by auto\n" + ");\n\nSELECT\n 1 AS auto" + ) + + # Multi-column → parens preserved, identifiers as-written + # (quoting happens when the model is loaded, not at format time) + assert format_model_expressions( + parse( + """ + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (a, b) + ); + SELECT 1 AS a, 2 AS b + """ + ) + ) == ( + "MODEL (\n" + " name db.test,\n" + " kind FULL,\n" + " dialect databricks,\n" + " clustered_by (a, b)\n" + ");\n\nSELECT\n 1 AS a,\n 2 AS b" + ) + + +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_format_model_expressions_clustered_by_non_databricks(keyword: str): + """AUTO/NONE without dialect or with a non-Databricks dialect is parsed as a bare identifier.""" + # Without dialect — AUTO/NONE parsed as a plain column name (no special keyword handling) + assert format_model_expressions( + parse( + f""" + MODEL ( + name db.test, + kind FULL, + clustered_by {keyword} + ); + SELECT 1 AS {keyword.lower()} + """ + ) + ) == ( + f"MODEL (\n" + f" name db.test,\n" + f" kind FULL,\n" + f" clustered_by {keyword}\n" + f");\n\nSELECT\n 1 AS {keyword.lower()}" + ) + + +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_format_model_expressions_clustered_by_mixed_list(keyword: str): + """AUTO/NONE inside a parenthesised list is treated as a regular column name.""" + assert format_model_expressions( + parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (a, {keyword}) + ); + SELECT 1 AS a, 2 AS {keyword.lower()} + """ + ) + ) == ( + f"MODEL (\n" + f" name db.test,\n" + f" kind FULL,\n" + f" dialect databricks,\n" + f" clustered_by (a, {keyword})\n" + f");\n\nSELECT\n 1 AS a,\n 2 AS {keyword.lower()}" + ) + + def test_connected_identifier(): ast = d.parse_one("""SELECT ("x"at time zone 'utc')::timestamp as x""", "redshift") assert ast.sql("redshift") == """SELECT CAST(("x" AT TIME ZONE 'utc') AS TIMESTAMP) AS x""" diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d4279a9c90..1530e0811b 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -334,7 +334,9 @@ def test_model_union_query(sushi_context, assert_exp_eq): "@get_date() == '1996-02-10'", "'all'", 3, - lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n", + lambda expected_select: ( + f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n" + ), ), # Test case 4: DISTINCT type ( @@ -374,7 +376,9 @@ def test_model_union_query(sushi_context, assert_exp_eq): "", "", 3, - lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n", + lambda expected_select: ( + f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n" + ), ), # Test case 9: Missing union type AND condition one table ( @@ -2183,6 +2187,97 @@ def test_render_definition_partitioned_by(): ) +def test_render_definition_clustered_by(): + # Unquoted AUTO keyword → rendered without backticks or parens + for keyword in ("AUTO", "NONE"): + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by {keyword} + ); + SELECT 1 AS a + """ + ) + ) + assert model.render_definition()[0].sql(pretty=True) == ( + f"MODEL (\n" + f" name db.test,\n" + f" dialect databricks,\n" + f" kind FULL,\n" + f" clustered_by {keyword}\n" + f")" + ) + + # Backtick-quoted `auto` / `none` → treated as a real column name, rendered quoted + for name in ("auto", "none"): + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by `{name}` + ); + SELECT 1 AS `{name}` + """ + ) + ) + assert model.render_definition()[0].sql(pretty=True) == ( + f"MODEL (\n" + f" name db.test,\n" + f" dialect databricks,\n" + f" kind FULL,\n" + f' clustered_by "{name}"\n' + f")" + ) + + # Parens-wrapped (AUTO) → treated as a real column name, rendered quoted + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (auto) + ); + SELECT 1 AS auto + """ + ) + ) + assert model.render_definition()[0].sql(pretty=True) == ( + 'MODEL (\n name db.test,\n dialect databricks,\n kind FULL,\n clustered_by "auto"\n)' + ) + + # Multi-column → rendered with parens, unchanged + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (a, b) + ); + SELECT 1 AS a, 2 AS b + """ + ) + ) + assert model.render_definition()[0].sql(pretty=True) == ( + "MODEL (\n" + " name db.test,\n" + " dialect databricks,\n" + " kind FULL,\n" + ' clustered_by ("a", "b")\n' + ")" + ) + + def test_render_definition_with_virtual_update_statements(): # model has virtual update statements model = load_sql_based_model( @@ -4050,6 +4145,138 @@ def test_model_normalization(): assert model.clustered_by == [exp.to_column('"A"'), exp.to_column('"B"')] +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_clustered_by_keyword(keyword: str): + # Via SQL DDL + expr = d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by {keyword} + ); + SELECT 1 AS a + """ + ) + model = load_sql_based_model(expr) + assert len(model.clustered_by) == 1 + assert model.clustered_by[0].sql(dialect="databricks").upper() == keyword + model.validate_definition() + + # Via Python API with exp.Var + model2 = create_sql_model( + "db.test", + parse_one("SELECT 1 AS a"), + dialect="databricks", + kind=FullKind(), + clustered_by=exp.Var(this=keyword), + ) + assert len(model2.clustered_by) == 1 + assert model2.clustered_by[0].sql(dialect="databricks").upper() == keyword + model2.validate_definition() + + # Via Python API with a plain string — must not silently become a quoted column + model3 = create_sql_model( + "db.test", + parse_one("SELECT 1 AS a"), + dialect="databricks", + kind=FullKind(), + clustered_by=keyword, + ) + assert len(model3.clustered_by) == 1 + assert isinstance(model3.clustered_by[0], exp.Var) + assert model3.clustered_by[0].name.upper() == keyword + model3.validate_definition() + + +def test_clustered_by_quoted_keyword_column(): + """A backtick-quoted column named `auto` or `none` is a real column, not a keyword.""" + for name in ("auto", "none"): + expr = d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by `{name}` + ); + SELECT 1 AS `{name}` + """ + ) + model = load_sql_based_model(expr) + assert len(model.clustered_by) == 1 + # Must be a Column (quoted identifier), not treated as a keyword + assert isinstance(model.clustered_by[0], exp.Column) + assert model.clustered_by[0].name.lower() == name + model.validate_definition() + + +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_clustered_by_keyword_non_databricks_dialect(keyword: str): + """AUTO/NONE should be rejected for non-Databricks dialects as they are meaningless there.""" + with pytest.raises((ConfigError, Exception)): + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect duckdb, + clustered_by {keyword} + ); + SELECT 1 AS a + """ + ) + ) + model.validate_definition() + + +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_clustered_by_mixed_list_pins_behaviour(keyword: str): + """clustered_by (a, AUTO) — AUTO alongside a real column is treated as a column named AUTO.""" + expr = d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by (a, {keyword}) + ); + SELECT 1 AS a, 2 AS {keyword.lower()} + """ + ) + model = load_sql_based_model(expr) + # Both entries are real columns (AUTO/NONE inside parens is a column, not a keyword) + assert len(model.clustered_by) == 2 + assert all(isinstance(c_expr, exp.Column) for c_expr in model.clustered_by) + model.validate_definition() + + +@pytest.mark.parametrize("keyword", ["AUTO", "NONE"]) +def test_clustered_by_keyword_serialisation_round_trip(keyword: str): + """exp.Var(AUTO/NONE) must survive JSON serialisation and deserialisation unchanged.""" + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.test, + kind FULL, + dialect databricks, + clustered_by {keyword} + ); + SELECT 1 AS a + """ + ) + ) + model_json = model.json() + deserialized = SqlModel.parse_raw(model_json) + assert deserialized.clustered_by == model.clustered_by + assert len(deserialized.clustered_by) == 1 + assert isinstance(deserialized.clustered_by[0], exp.Var) + assert deserialized.clustered_by[0].name.upper() == keyword + + def test_incremental_unmanaged_validation(): model = create_sql_model( "a",