diff --git a/docs/concepts/models/python_models.md b/docs/concepts/models/python_models.md index 10884ecedf..8809364fba 100644 --- a/docs/concepts/models/python_models.md +++ b/docs/concepts/models/python_models.md @@ -369,6 +369,33 @@ def entrypoint( ) ``` +Blueprint variables can also be used as **column names and column types** in the `columns` dictionary. For example, when each blueprint produces a model with its own column names and types, both the names and the types can be parameterized with the same `@{variable}` syntax that is used in the model name: + +```python linenums="1" +import pandas as pd +from sqlmesh import ExecutionContext, model + +@model( + "@{customer}.metrics", + kind="FULL", + blueprints=[ + {"customer": "customer1", "primary_metric": "revenue", "primary_type": "int", "secondary_metric": "cost", "secondary_type": "double"}, + {"customer": "customer2", "primary_metric": "sales", "primary_type": "text", "secondary_metric": "profit", "secondary_type": "double"}, + ], + columns={ + "@{primary_metric}": "@{primary_type}", + "@{secondary_metric}": "@{secondary_type}", + }, +) +def entrypoint(context: ExecutionContext, **kwargs) -> pd.DataFrame: + return pd.DataFrame({ + context.blueprint_var("primary_metric"): [1], + context.blueprint_var("secondary_metric"): [1.5], + }) +``` + +Global variables (defined in the project config) can also be used as column names and types, using the same `@{variable}` syntax. + Note the use of curly brace syntax `@{customer}` in the model name above. It is used to ensure SQLMesh can combine the macro variable into the model name identifier correctly - learn more [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings). Blueprint variable mappings can also be constructed dynamically, e.g., by using a macro: `blueprints="@gen_blueprints()"`. This is useful in cases where the `blueprints` list needs to be sourced from external sources, such as CSV files. 
diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 94b8c2f2ad..1aa7ca839a 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1096,8 +1096,9 @@ def extend_sqlglot() -> None: DColonCast: lambda self, e: f"{self.sql(e, 'this')}::{self.sql(e, 'to')}", Jinja: lambda self, e: e.name, JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};", - JinjaStatement: lambda self, - e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};", + JinjaStatement: lambda self, e: ( + f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};" + ), VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e), MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})", MacroFunc: _macro_func_sql, diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 9081227cb7..7295714087 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -300,8 +300,8 @@ def _get_source_queries( ) for c in target_columns_to_types ] - query_factory = ( - lambda: exp.Select() + query_factory = lambda: ( + exp.Select() .select(*select_columns) .from_(query_or_df.subquery("select_source_columns")) ) diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 328b763f9f..54a43e0080 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -74,13 +74,13 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: self.columns = { column_name: ( - column_type - if isinstance(column_type, exp.DataType) + column_type # Column types with macros (containing @) will be validated later after rendering + if isinstance(column_type, exp.DataType) or "@" in column_type else exp.DataType.build( str(column_type), dialect=self.kwargs.get("dialect", self._dialect) ) ) - for column_name, column_type in self.kwargs.pop("columns", {}).items() + for column_name, column_type in self.kwargs.get("columns", 
{}).items() } def __call__( @@ -196,6 +196,8 @@ def model( if isinstance(rendered_name, exp.Expr): rendered_fields["name"] = rendered_name.sql(dialect=dialect) + rendered_columns = rendered_fields.get("columns") + rendered_defaults = ( render_model_defaults( defaults=defaults, @@ -223,7 +225,7 @@ def model( "default_catalog": default_catalog, "variables": variables, "dialect": dialect, - "columns": self.columns if self.columns else None, + "columns": rendered_columns if rendered_columns else None, "module_path": module_path, "macros": macros, "jinja_macros": jinja_macros, diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 3533d1b669..a059fa4e87 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -2977,7 +2977,15 @@ def render_field_value(value: t.Any) -> t.Any: if isinstance(field_value, dict): rendered_dict = {} for key, value in field_value.items(): - if key in RUNTIME_RENDERED_MODEL_FIELDS: + if field == "columns": + column_name = render_field_value(key) + column_type = render_field_value(value) + # If column_type is an Expr (from rendering macros), convert to string. + # Otherwise, leave it as-is (string) for the validator to parse with the correct dialect. + if isinstance(column_type, exp.Expr): + column_type = column_type.sql(dialect=dialect) + rendered_dict[column_name] = column_type + elif key in RUNTIME_RENDERED_MODEL_FIELDS: rendered_dict[key] = parse_strings_with_macro_refs(value, dialect) elif ( # don't parse kind auto_restatement_cron="@..." kwargs (e.g. 
@daily) into MacroVar diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 73c4e5681b..5881e1ece7 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -332,8 +332,9 @@ def get_model_find_all_references( # Find the model reference at the cursor position model_at_position = next( filter( - lambda ref: isinstance(ref, ModelReference) - and _position_within_range(position, ref.range), + lambda ref: ( + isinstance(ref, ModelReference) and _position_within_range(position, ref.range) + ), get_model_definitions_for_a_path(lint_context, document_uri), ), None, @@ -486,8 +487,9 @@ def get_macro_find_all_references( # Find the macro reference at the cursor position macro_at_position = next( filter( - lambda ref: isinstance(ref, MacroReference) - and _position_within_range(position, ref.range), + lambda ref: ( + isinstance(ref, MacroReference) and _position_within_range(position, ref.range) + ), get_macro_definitions_for_a_path(lsp_context, document_uri), ), None, @@ -517,9 +519,11 @@ def get_macro_find_all_references( # Get macro references that point to the same macro definition matching_refs = filter( - lambda ref: isinstance(ref, MacroReference) - and ref.path == target_macro_path - and ref.target_range == target_macro_target_range, + lambda ref: ( + isinstance(ref, MacroReference) + and ref.path == target_macro_path + and ref.target_range == target_macro_target_range + ), get_macro_definitions_for_a_path(lsp_context, file_uri), ) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d4279a9c90..ef6b0080d2 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -334,7 +334,9 @@ def test_model_union_query(sushi_context, assert_exp_eq): "@get_date() == '1996-02-10'", "'all'", 3, - lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n", + lambda expected_select: ( + f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n" + ), ), 
# Test case 4: DISTINCT type ( @@ -374,7 +376,9 @@ def test_model_union_query(sushi_context, assert_exp_eq): "", "", 3, - lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n", + lambda expected_select: ( + f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n" + ), ), # Test case 9: Missing union type AND condition one table ( @@ -10353,6 +10357,94 @@ def entrypoint(context, *args, **kwargs): assert ctx.fetchdf("SELECT * FROM test_schema2.foo").to_dict() == {"id": {0: 1}} +def test_python_model_blueprint_column_names(tmp_path: Path) -> None: + """Blueprint variables can be used as column names and types in Python model definitions.""" + py_model = tmp_path / "models" / "blueprint_col_names.py" + py_model.parent.mkdir(parents=True, exist_ok=True) + py_model.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlmesh import model + +@model( + "test_schema.@model_name", + blueprints=[ + {"model_name": "hotel_revenue", "col_a": "revenue", "type_a": "int", "col_b": "cost", "type_b": "double"}, + {"model_name": "coffee_sales", "col_a": "sales", "type_a": "bigint", "col_b": "profit", "type_b": "text"}, + ], + kind="FULL", + columns={ + "@{col_a}": "@{type_a}", + "@{col_b}": "@{type_b}", + }, +) +def entrypoint(context, *args, **kwargs): + return pd.DataFrame({ + context.blueprint_var("col_a"): [1], + context.blueprint_var("col_b"): [1.5], + }) + """ + ) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + paths=tmp_path, + ) + assert len(ctx.models) == 2 + + model1 = ctx.get_model("test_schema.hotel_revenue", raise_if_missing=True) + model2 = ctx.get_model("test_schema.coffee_sales", raise_if_missing=True) + + assert model1.columns_to_types_ is not None + assert set(model1.columns_to_types_.keys()) == {"revenue", "cost"} + assert model1.columns_to_types_["revenue"] == exp.DataType.build("int") + assert model1.columns_to_types_["cost"] == 
exp.DataType.build("double") + + assert model2.columns_to_types_ is not None + assert set(model2.columns_to_types_.keys()) == {"sales", "profit"} + assert model2.columns_to_types_["sales"] == exp.DataType.build("bigint") + assert model2.columns_to_types_["profit"] == exp.DataType.build("text") + + +def test_python_model_variable_column_names(tmp_path: Path) -> None: + """Global variables can be used as column names in Python model definitions.""" + py_model = tmp_path / "models" / "var_col_names.py" + py_model.parent.mkdir(parents=True, exist_ok=True) + py_model.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlmesh import model + +@model( + "test_schema.model", + kind="FULL", + columns={ + "@{metric_col}": "int", + "static_col": "text", + }, +) +def entrypoint(context, *args, **kwargs): + return pd.DataFrame({"revenue": [1], "static_col": ["x"]}) + """ + ) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"metric_col": "revenue"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 1 + + model = ctx.get_model("test_schema.model", raise_if_missing=True) + + assert model.columns_to_types_ is not None + assert set(model.columns_to_types_.keys()) == {"revenue", "static_col"} + assert model.columns_to_types_["revenue"] == exp.DataType.build("int") + assert model.columns_to_types_["static_col"] == exp.DataType.build("text") + + @time_machine.travel("2020-01-01 00:00:00 UTC") def test_dynamic_date_spine_model(assert_exp_eq): @macro() diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index f93a8a4780..eb3f965761 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -692,8 +692,8 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): finalized_ts=to_timestamp("2023-01-02"), ) - state_reader.get_environment.side_effect = ( - lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + 
state_reader.get_environment.side_effect = lambda name: ( + existing_dev_environment if name == "dev" else existing_prod_environment ) state_reader.get_environments_summary.return_value = [ existing_prod_environment.summary, @@ -857,8 +857,8 @@ def test_build_plan_stages_restatement_dev_does_not_clear_intervals( finalized_ts=to_timestamp("2023-01-02"), ) - state_reader.get_environment.side_effect = ( - lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + state_reader.get_environment.side_effect = lambda name: ( + existing_dev_environment if name == "dev" else existing_prod_environment ) state_reader.get_environments_summary.return_value = [ existing_prod_environment.summary, diff --git a/tests/core/test_selector_native.py b/tests/core/test_selector_native.py index 07cafe095b..e8b6f8a7ad 100644 --- a/tests/core/test_selector_native.py +++ b/tests/core/test_selector_native.py @@ -231,8 +231,8 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot) ) state_reader_mock = mocker.Mock() - state_reader_mock.get_environment.side_effect = ( - lambda name: prod_env if name == "prod" else dev_env + state_reader_mock.get_environment.side_effect = lambda name: ( + prod_env if name == "prod" else dev_env ) all_snapshots = { @@ -875,8 +875,8 @@ def test_select_models_selected_fqns_fallback(mocker: MockerFixture, make_snapsh ) state_reader_mock = mocker.Mock() - state_reader_mock.get_environment.side_effect = ( - lambda name: fallback_env if name == "prod" else None + state_reader_mock.get_environment.side_effect = lambda name: ( + fallback_env if name == "prod" else None ) state_reader_mock.get_snapshots.return_value = { deleted_model_snapshot.snapshot_id: deleted_model_snapshot,