Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 90 additions & 22 deletions src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,17 @@ def _normalize_parameters_for_pybind(

return normalized_query, normalized_params

def _is_python_scan_object(self, value: Any) -> bool:
@staticmethod
def _is_python_scan_object(value: Any) -> bool:
module_name = type(value).__module__
return module_name.startswith(("pandas", "polars", "pyarrow"))

def _has_scan_pattern(self, query: str) -> bool:
stripped = query.lstrip()
if not (
stripped.upper().startswith("LOAD ") or stripped.upper().startswith("COPY ")
):
return False
return re.search(r"(?i)\bFROM\b", query) is not None
@staticmethod
def _has_scan_pattern(query: str) -> bool:
matches = re.search(
r"^\s*\b(LOAD|COPY)\b.*?\bFROM\b", query, re.IGNORECASE | re.DOTALL
)
return matches is not None

@staticmethod
def _quote_identifier(identifier: str) -> str:
Expand Down Expand Up @@ -331,7 +331,7 @@ def _rewrite_capi_python_scan(
if self._using_pybind_backend() or not self._has_scan_pattern(query):
return query, parameters

for key, value in list(parameters.items()):
for key, value in parameters.items():
if not isinstance(key, str):
continue
match = re.search(rf"(?i)\bFROM\s+(\${re.escape(key)})\b", query)
Expand All @@ -340,8 +340,8 @@ def _rewrite_capi_python_scan(
if not self._is_python_scan_object(value):
msg = (
"Binder exception: Trying to scan from unsupported data type "
"INT8[]. The only parameter types that can be scanned from "

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looked like a typo

"are pandas/polars dataframes and pyarrow tables."
f"{type(value).__name__}. The only parameter types that can "
"be scanned from are pandas/polars dataframes and pyarrow tables."
)
raise RuntimeError(msg)
if self.database.read_only:
Expand Down Expand Up @@ -463,9 +463,12 @@ def _execute_with_pybind(
prepared = py_connection.prepare(query, parameters)
return py_connection.execute(prepared, parameters)

def _maybe_raise_scan_unsupported_object(self, query: str) -> None:
@staticmethod
def _maybe_raise_scan_unsupported_object(query: str) -> None:
match = re.search(
r"\bLOAD\s+FROM\s+([A-Za-z_][A-Za-z0-9_]*)\b", query, re.IGNORECASE
r"\b(LOAD|COPY)\b.*?\bFROM\s+([A-Za-z_][A-Za-z0-9_]*)\b",

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex should catch all of LOAD, LOAD WITH... and COPY FROM

query,
re.IGNORECASE | re.DOTALL,
)
if not match:
return
Expand All @@ -480,19 +483,79 @@ def _maybe_raise_scan_unsupported_object(self, query: str) -> None:
return

scope = {**caller.f_globals, **caller.f_locals}
if var_name not in scope:
return
value = scope.get(var_name)

value = scope[var_name]
module_name = type(value).__module__
if module_name.startswith(("pandas", "polars", "pyarrow")):
if value is None or Connection._is_python_scan_object(value):
return

msg = (
"Binder exception: Attempted to scan from unsupported python object. "
"Can only scan from pandas/polars dataframes and pyarrow tables."
raise RuntimeError(
Connection._unsupported_scan_object_parameter_message(var_name, value)
)

@staticmethod
def _scan_parameter_names(query: str) -> set[str]:
matches = re.findall(
r"\b(LOAD|COPY)\b.*?\bFROM\s+\$([A-Za-z_][A-Za-z0-9_]*)\b",
query,
re.IGNORECASE | re.DOTALL,
)
return {match[1] for match in matches}

@staticmethod
def _unsupported_scan_object_parameter_message(key: str, value: Any) -> str:
return (
f"Binder exception: Unsupported parameter type {type(value).__name__} "
f"for parameter ${key}. Pandas / polars DataFrames and PyArrow "
"Tables can only be used as LOAD FROM / COPY FROM scan sources."
)

@staticmethod
def _capi_prepared_scan_parameter_message() -> str:
return (
"Binder exception: PreparedStatement with Python dataframe/table scan "
"parameters is not supported on the C-API backend. Use "
"conn.execute(query_string, params) instead."
)
raise RuntimeError(msg)

@staticmethod
def _prepared_scan_parameter_message(key: str, value: Any) -> str:
return (
f"Binder exception: Unsupported parameter type {type(value).__name__} "
f"for parameter ${key}. This PreparedStatement does not use ${key} "
"as a LOAD FROM / COPY FROM scan source."
)

def _maybe_raise_scan_unsupported_args(
self,
query: str | PreparedStatement,
parameters: dict[str, Any],
*,
for_prepare: bool = False,
) -> None:
if isinstance(query, str):
is_prepared = False
scan_parameter_names = self._scan_parameter_names(query)
elif isinstance(query, PreparedStatement):
is_prepared = True
scan_parameter_names = query._scan_parameter_names

supports_scan_parameter_execute = self._using_pybind_backend()
for key, value in parameters.items():
if not isinstance(key, str):
continue
if not self._is_python_scan_object(value):
continue
if key not in scan_parameter_names:
if is_prepared:
raise RuntimeError(
self._prepared_scan_parameter_message(key, value)
)
else:
raise RuntimeError(
self._unsupported_scan_object_parameter_message(key, value)
)
if (for_prepare or is_prepared) and not supports_scan_parameter_execute:
raise RuntimeError(self._capi_prepared_scan_parameter_message())

def execute(
self,
Expand Down Expand Up @@ -532,6 +595,8 @@ def execute(
query, parameters = self._rewrite_capi_python_scan(query, parameters)
scan_tables_to_drop = self._capi_scan_tables - scan_tables_before

self._maybe_raise_scan_unsupported_args(query, parameters)

if (
not self._using_pybind_backend()
and self._query_timeout_ms > 0
Expand Down Expand Up @@ -628,6 +693,9 @@ def _prepare(
The only parameters supported during prepare are dataframes.
Any remaining parameters will be ignored and should be passed to execute().
""" # noqa: D401
if parameters is None:
parameters = {}
self._maybe_raise_scan_unsupported_args(query, parameters, for_prepare=True)
return PreparedStatement(self, query, parameters)

def prepare(
Expand Down
1 change: 1 addition & 0 deletions src_py/prepared_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
parameters = {}
self._prepared_statement = connection._connection.prepare(query, parameters)
self._connection = connection
self._scan_parameter_names = connection._scan_parameter_names(query)

def is_success(self) -> bool:
"""
Expand Down
2 changes: 2 additions & 0 deletions test/capi_xfails.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
# C API scan rewriting uses temporary Arrow-backed tables, which cannot be
# created through a read-only connection.
"test/test_scan_pyarrow.py::test_pyarrow_basic",
#
"test/test_scan_parameter.py::test_copy_from_load_scan_object_param_still_works",
# UDF registration is still routed through pybind.
"test/test_blob_parameter.py::test_bytes_param_udf",
"test/test_udf.py::test_udf",
Expand Down
104 changes: 104 additions & 0 deletions test/test_scan_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from collections.abc import Callable
from typing import Any

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from type_aliases import ConnDB

ScanObjectFactory = Callable[[], Any]


@pytest.mark.parametrize(
("make_scan_object", "parameter_name"),
[
(lambda: pd.DataFrame({"col1": [1, 2, 3]}), "df"),
(lambda: pl.DataFrame({"col1": [1, 2, 3]}), "df"),
(lambda: pa.table({"col1": [1, 2, 3]}), "tab"),
],
)
def test_scan_object_param_rejected_outside_scan_source(
conn_db_empty: ConnDB,
make_scan_object: ScanObjectFactory,
parameter_name: str,
) -> None:
conn, _ = conn_db_empty

with pytest.raises(
RuntimeError,
match=rf"Unsupported parameter type .* for parameter \${parameter_name}.*LOAD FROM / COPY FROM",
):
conn.execute(
f"RETURN ${parameter_name}",
{parameter_name: make_scan_object()},
)


def test_unreferenced_scan_object_param_is_rejected(conn_db_empty: ConnDB) -> None:
conn, _ = conn_db_empty
df = pd.DataFrame({"col1": [1, 2, 3]})

with pytest.raises(
RuntimeError,
match=r"Unsupported parameter type DataFrame for parameter \$df",
):
conn.execute("RETURN 1", {"df": df})


def test_scan_object_param_with_regular_param_still_works(
conn_db_empty: ConnDB,
) -> None:
conn, _ = conn_db_empty
df = pd.DataFrame({"col1": [1, 2, 3]})

result = conn.execute(
"LOAD FROM $df WHERE col1 = $x RETURN col1",
{"df": df, "x": 2},
)

assert result.get_next() == [2]
assert not result.has_next()


def test_copy_from_load_scan_object_param_still_works(conn_db_empty: ConnDB) -> None:
conn, _ = conn_db_empty
df = pd.DataFrame({"col1": [10, 20, 30]})

conn.execute("CREATE NODE TABLE T(id INT64 PRIMARY KEY)")
conn.execute("COPY T FROM (LOAD FROM $df RETURN col1 AS id)", {"df": df})
result = conn.execute("MATCH (t:T) RETURN t.id ORDER BY t.id")

assert result.get_next() == [10]
assert result.get_next() == [20]
assert result.get_next() == [30]
assert not result.has_next()


def test_prepare_rejects_scan_object_param_outside_scan_source(
conn_db_empty: ConnDB,
) -> None:
conn, _ = conn_db_empty
df = pd.DataFrame({"col1": [1, 2, 3]})

with (
pytest.warns(DeprecationWarning, match="separate prepare"),
pytest.raises(
RuntimeError,
match=r"Unsupported parameter type DataFrame for parameter \$df",
),
):
conn.prepare("RETURN $df", {"df": df})


def test_prepared_execute_rejects_unknown_scan_object_param(
conn_db_empty: ConnDB,
) -> None:
conn, _ = conn_db_empty
df = pd.DataFrame({"col1": [1, 2, 3]})

with pytest.warns(DeprecationWarning, match="separate prepare"):
prepared = conn.prepare("RETURN $x")

with pytest.raises(RuntimeError, match=r"does not use \$df"):
conn.execute(prepared, {"x": 1, "df": df})
2 changes: 1 addition & 1 deletion test/test_scan_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_pyarrow_copy_from_invalid_source(conn_db_readwrite: ConnDB) -> None:
)
with pytest.raises(
RuntimeError,
match=r"Binder exception: Trying to scan from unsupported data type INT(8|64)\[\]. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.",
match=r"Binder exception: Trying to scan from unsupported data type list. The only parameter types that can be scanned from are pandas/polars dataframes and pyarrow tables.",
):
conn.execute("COPY pyarrowtab FROM $tab", {"tab": [1, 2, 3]})

Expand Down
Loading