diff --git a/src_py/connection.py b/src_py/connection.py index fa41b89..a481cbf 100644 --- a/src_py/connection.py +++ b/src_py/connection.py @@ -245,17 +245,17 @@ def _normalize_parameters_for_pybind( return normalized_query, normalized_params - def _is_python_scan_object(self, value: Any) -> bool: + @staticmethod + def _is_python_scan_object(value: Any) -> bool: module_name = type(value).__module__ return module_name.startswith(("pandas", "polars", "pyarrow")) - def _has_scan_pattern(self, query: str) -> bool: - stripped = query.lstrip() - if not ( - stripped.upper().startswith("LOAD ") or stripped.upper().startswith("COPY ") - ): - return False - return re.search(r"(?i)\bFROM\b", query) is not None + @staticmethod + def _has_scan_pattern(query: str) -> bool: + matches = re.search( + r"^\s*\b(LOAD|COPY)\b.*?\bFROM\b", query, re.IGNORECASE | re.DOTALL + ) + return matches is not None @staticmethod def _quote_identifier(identifier: str) -> str: @@ -331,7 +331,7 @@ def _rewrite_capi_python_scan( if self._using_pybind_backend() or not self._has_scan_pattern(query): return query, parameters - for key, value in list(parameters.items()): + for key, value in parameters.items(): if not isinstance(key, str): continue match = re.search(rf"(?i)\bFROM\s+(\${re.escape(key)})\b", query) @@ -340,8 +340,8 @@ def _rewrite_capi_python_scan( if not self._is_python_scan_object(value): msg = ( "Binder exception: Trying to scan from unsupported data type " - "INT8[]. The only parameter types that can be scanned from " - "are pandas/polars dataframes and pyarrow tables." + f"{type(value).__name__}. The only parameter types that can " + "be scanned from are pandas/polars dataframes and pyarrow tables." ) raise RuntimeError(msg) if self.database.read_only: @@ -463,9 +463,12 @@ def _execute_with_pybind( prepared = py_connection.prepare(query, parameters) return py_connection.execute(prepared, parameters) - def _maybe_raise_scan_unsupported_object(self, query: str) -> None: + @staticmethod + def _maybe_raise_scan_unsupported_object(query: str) -> None: match = re.search( - r"\bLOAD\s+FROM\s+([A-Za-z_][A-Za-z0-9_]*)\b", query, re.IGNORECASE + r"\b(LOAD|COPY)\b.*?\bFROM\s+([A-Za-z_][A-Za-z0-9_]*)\b", + query, + re.IGNORECASE | re.DOTALL, ) if not match: return @@ -480,19 +483,79 @@ def _maybe_raise_scan_unsupported_object(self, query: str) -> None: return scope = {**caller.f_globals, **caller.f_locals} - if var_name not in scope: - return + value = scope.get(var_name) - value = scope[var_name] - module_name = type(value).__module__ - if module_name.startswith(("pandas", "polars", "pyarrow")): + if value is None or Connection._is_python_scan_object(value): return - msg = ( - "Binder exception: Attempted to scan from unsupported python object. " - "Can only scan from pandas/polars dataframes and pyarrow tables." + raise RuntimeError( + Connection._unsupported_scan_object_parameter_message(var_name, value) + ) + + @staticmethod + def _scan_parameter_names(query: str) -> set[str]: + matches = re.findall( + r"\b(LOAD|COPY)\b.*?\bFROM\s+\$([A-Za-z_][A-Za-z0-9_]*)\b", + query, + re.IGNORECASE | re.DOTALL, + ) + return {match[1] for match in matches} + + @staticmethod + def _unsupported_scan_object_parameter_message(key: str, value: Any) -> str: + return ( + f"Binder exception: Unsupported parameter type {type(value).__name__} " + f"for parameter ${key}. Pandas / polars DataFrames and PyArrow " + "Tables can only be used as LOAD FROM / COPY FROM scan sources." + ) + + @staticmethod + def _capi_prepared_scan_parameter_message() -> str: + return ( + "Binder exception: PreparedStatement with Python dataframe/table scan " + "parameters is not supported on the C-API backend. Use " + "conn.execute(query_string, params) instead." ) - raise RuntimeError(msg) + + @staticmethod + def _prepared_scan_parameter_message(key: str, value: Any) -> str: + return ( + f"Binder exception: Unsupported parameter type {type(value).__name__} " + f"for parameter ${key}. This PreparedStatement does not use ${key} " + "as a LOAD FROM / COPY FROM scan source." + ) + + def _maybe_raise_scan_unsupported_args( + self, + query: str | PreparedStatement, + parameters: dict[str, Any], + *, + for_prepare: bool = False, + ) -> None: + if isinstance(query, str): + is_prepared = False + scan_parameter_names = self._scan_parameter_names(query) + elif isinstance(query, PreparedStatement): + is_prepared = True + scan_parameter_names = query._scan_parameter_names + + supports_scan_parameter_execute = self._using_pybind_backend() + for key, value in parameters.items(): + if not isinstance(key, str): + continue + if not self._is_python_scan_object(value): + continue + if key not in scan_parameter_names: + if is_prepared: + raise RuntimeError( + self._prepared_scan_parameter_message(key, value) + ) + else: + raise RuntimeError( + self._unsupported_scan_object_parameter_message(key, value) + ) + if (for_prepare or is_prepared) and not supports_scan_parameter_execute: + raise RuntimeError(self._capi_prepared_scan_parameter_message()) def execute( self, @@ -532,6 +595,8 @@ def execute( query, parameters = self._rewrite_capi_python_scan(query, parameters) scan_tables_to_drop = self._capi_scan_tables - scan_tables_before + self._maybe_raise_scan_unsupported_args(query, parameters) + if ( not self._using_pybind_backend() and self._query_timeout_ms > 0 @@ -628,6 +693,9 @@ def _prepare( The only parameters supported during prepare are dataframes. Any remaining parameters will be ignored and should be passed to execute(). """ # noqa: D401 + if parameters is None: + parameters = {} + self._maybe_raise_scan_unsupported_args(query, parameters, for_prepare=True) return PreparedStatement(self, query, parameters) def prepare( diff --git a/src_py/prepared_statement.py b/src_py/prepared_statement.py index 25efe9a..4aeb455 100644 --- a/src_py/prepared_statement.py +++ b/src_py/prepared_statement.py @@ -32,6 +32,7 @@ def __init__( parameters = {} self._prepared_statement = connection._connection.prepare(query, parameters) self._connection = connection + self._scan_parameter_names = connection._scan_parameter_names(query) def is_success(self) -> bool: """ diff --git a/test/capi_xfails.py b/test/capi_xfails.py index c928e6f..6970e32 100644 --- a/test/capi_xfails.py +++ b/test/capi_xfails.py @@ -69,6 +69,8 @@ # C API scan rewriting uses temporary Arrow-backed tables, which cannot be # created through a read-only connection. "test/test_scan_pyarrow.py::test_pyarrow_basic", + # + "test/test_scan_parameter.py::test_copy_from_load_scan_object_param_still_works", # UDF registration is still routed through pybind. "test/test_blob_parameter.py::test_bytes_param_udf", "test/test_udf.py::test_udf", diff --git a/test/test_scan_parameter.py b/test/test_scan_parameter.py new file mode 100644 index 0000000..5a2acc3 --- /dev/null +++ b/test/test_scan_parameter.py @@ -0,0 +1,104 @@ +from collections.abc import Callable +from typing import Any + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest +from type_aliases import ConnDB + +ScanObjectFactory = Callable[[], Any] + + +@pytest.mark.parametrize( + ("make_scan_object", "parameter_name"), + [ + (lambda: pd.DataFrame({"col1": [1, 2, 3]}), "df"), + (lambda: pl.DataFrame({"col1": [1, 2, 3]}), "df"), + (lambda: pa.table({"col1": [1, 2, 3]}), "tab"), + ], +) +def test_scan_object_param_rejected_outside_scan_source( + conn_db_empty: ConnDB, + make_scan_object: ScanObjectFactory, + parameter_name: str, +) -> None: + conn, _ = conn_db_empty + + with pytest.raises( + RuntimeError, + match=rf"Unsupported parameter type .* for parameter \${parameter_name}.*LOAD FROM / COPY FROM", + ): + conn.execute( + f"RETURN ${parameter_name}", + {parameter_name: make_scan_object()}, + ) + + +def test_unreferenced_scan_object_param_is_rejected(conn_db_empty: ConnDB) -> None: + conn, _ = conn_db_empty + df = pd.DataFrame({"col1": [1, 2, 3]}) + + with pytest.raises( + RuntimeError, + match=r"Unsupported parameter type DataFrame for parameter \$df", + ): + conn.execute("RETURN 1", {"df": df}) + + +def test_scan_object_param_with_regular_param_still_works( + conn_db_empty: ConnDB, +) -> None: + conn, _ = conn_db_empty + df = pd.DataFrame({"col1": [1, 2, 3]}) + + result = conn.execute( + "LOAD FROM $df WHERE col1 = $x RETURN col1", + {"df": df, "x": 2}, + ) + + assert result.get_next() == [2] + assert not result.has_next() + + +def test_copy_from_load_scan_object_param_still_works(conn_db_empty: ConnDB) -> None: + conn, _ = conn_db_empty + df = pd.DataFrame({"col1": [10, 20, 30]}) + + conn.execute("CREATE NODE TABLE T(id INT64 PRIMARY KEY)") + conn.execute("COPY T FROM (LOAD FROM $df RETURN col1 AS id)", {"df": df}) + result = conn.execute("MATCH (t:T) RETURN t.id ORDER BY t.id") + + assert result.get_next() == [10] + assert result.get_next() == [20] + assert result.get_next() == [30] + assert not result.has_next() + + +def test_prepare_rejects_scan_object_param_outside_scan_source( + conn_db_empty: ConnDB, +) -> None: + conn, _ = conn_db_empty + df = pd.DataFrame({"col1": [1, 2, 3]}) + + with ( + pytest.warns(DeprecationWarning, match="separate prepare"), + pytest.raises( + RuntimeError, + match=r"Unsupported parameter type DataFrame for parameter \$df", + ), + ): + conn.prepare("RETURN $df", {"df": df}) + + +def test_prepared_execute_rejects_unknown_scan_object_param( + conn_db_empty: ConnDB, +) -> None: + conn, _ = conn_db_empty + df = pd.DataFrame({"col1": [1, 2, 3]}) + + with pytest.warns(DeprecationWarning, match="separate prepare"): + prepared = conn.prepare("RETURN $x") + + with pytest.raises(RuntimeError, match=r"does not use \$df"): + conn.execute(prepared, {"x": 1, "df": df}) diff --git a/test/test_scan_pyarrow.py b/test/test_scan_pyarrow.py index c3d0b92..014ceae 100644 --- a/test/test_scan_pyarrow.py +++ b/test/test_scan_pyarrow.py @@ -137,7 +137,7 @@ def test_pyarrow_copy_from_invalid_source(conn_db_readwrite: ConnDB) -> None: ) with pytest.raises( RuntimeError, - match=r"Binder exception: Trying to scan from unsupported data type INT(8|64)\[\]. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.", + match=r"Binder exception: Trying to scan from unsupported data type list. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.", ): conn.execute("COPY pyarrowtab FROM $tab", {"tab": [1, 2, 3]})