Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 238 additions & 7 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,174 @@ def _normalize_time_param(value, c_type):
return None


class _ArrowReader:
"""RecordBatchReader-compatible wrapper that makes ``close()`` actually
release server-side resources.

``pyarrow.RecordBatchReader.from_batches(...)`` returns a reader whose
``close()`` only releases the internal ArrowArrayStream — it does **not**
propagate into the underlying Python generator and does **not** stop the
server-side ODBC cursor. This wrapper closes that gap.

Design (optimized):
* The Python generator backing the reader carries its own ``try/finally``
block — so server-side cleanup runs symmetrically whether the user
exhausts the reader, calls ``close()`` mid-iteration, exits a ``with``
block, or just lets the reader be garbage-collected. ``close()``
itself only has to (a) call ``SQLCancel`` to unblock any fetch in
flight on another thread and (b) close the generator; the
``finally`` clause does the rest.
* ``SQLCancel`` is called *before* ``SQLFreeStmt(SQL_CLOSE)`` so a fetch
running on another thread returns cleanly first. ``SQLCancel`` is
the single ODBC entry point (with the diag-record functions) that the
spec marks as safe to call from a different thread than the one
owning the statement.
* Diagnostics are drained *before* the cursor is closed, so records
produced by a cancelled fetch are not lost; a second drain after
close picks up anything ``SQL_CLOSE`` itself emits.
* Cached ``pyarrow.ArrowInvalid`` avoids per-read imports on the
post-close error path.
* ``__del__`` is guarded against interpreter finalization.

The parent ``Cursor`` is **not** closed; it remains fully usable.
"""

__slots__ = ("_cursor", "_inner", "_generator", "_closed", "_arrow_invalid")

def __init__(
self,
cursor: "Cursor",
inner: "pyarrow.RecordBatchReader",
generator,
arrow_invalid_exc: type,
) -> None:
self._cursor = cursor
self._inner = inner
self._generator = generator
self._closed = False
# Cache the exception class so post-close reads in a hot loop don't
# re-import pyarrow.
self._arrow_invalid = arrow_invalid_exc

# ── Public surface mirroring pyarrow.RecordBatchReader ────────────────

@property
def closed(self) -> bool:
"""True once ``close()`` has been called."""
return self._closed

@property
def schema(self):
"""Schema of the record batches produced by this reader."""
if self._closed:
raise self._arrow_invalid("Reader is closed")
return self._inner.schema

def read_next_batch(self):
if self._closed:
raise self._arrow_invalid("Reader is closed")
return self._inner.read_next_batch()

def __iter__(self):
return self

def __next__(self):
if self._closed:
raise self._arrow_invalid("Reader is closed")
return self._inner.read_next_batch()

def __enter__(self):
if self._closed:
raise self._arrow_invalid("Reader is closed")
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False

def __del__(self):
# Best-effort cleanup if the user never called close() (or a previous
# close() attempt failed to release the generator and left cleanup
# incomplete) and the reader is being garbage-collected. Skip during
# interpreter shutdown — the module globals (pyarrow, ddbc_bindings)
# may already be torn down, and touching native code at that point is
# unsafe.
try:
import sys as _sys

if _sys.is_finalizing():
return
# Retry whenever the generator is still referenced — covers both
# "user never called close()" and "earlier close() raised before
# the generator was released".
if getattr(self, "_generator", None) is not None:
self.close()
except Exception: # pylint: disable=broad-exception-caught
pass

# ── Close implementation ──────────────────────────────────────────────

def close(self) -> None:
"""Synchronously stop fetching, release the server-side cursor, and
reset parent-cursor bookkeeping. Idempotent **and retry-safe**:
if a previous call raised before the generator was released (for
example because another thread was still executing it and
``generator.close()`` raised ``ValueError: generator already
executing``), subsequent calls will pick up where the failed call
left off rather than silently no-op'ing.

Most of the actual cleanup work lives in the generator's ``finally``
clause (see ``Cursor.arrow_reader``); this method just unblocks any
in-flight fetch and closes the generator, which triggers that
``finally`` block.
"""
# Fast path: cleanup already completed on a previous call. We use
# the *generator* reference — not ``_closed`` — as the completion
# marker, because ``_closed`` is flipped early (so racing reads
# raise) and must not by itself disable retry of failed cleanup.
if self._generator is None and self._cursor is None:
self._closed = True
return

# Mark closed first so any racing read raises immediately, even if
# the cleanup steps below fail and we end up retried later.
self._closed = True

# SQLCancel (cross-thread safe) — unblocks a fetch running on another
# thread so that the generator's finally clause can then run
# SQLFreeStmt(SQL_CLOSE) without risking the undefined-behaviour
# window of closing an HSTMT mid-fetch. Safe no-op for an idle stmt.
cursor = self._cursor
if cursor is not None and not cursor.closed and cursor.hstmt is not None:
try:
cursor.hstmt._cancel() # pylint: disable=protected-access
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader.close: SQLCancel raised: %s", e)

# Close the generator — this raises GeneratorExit inside it, which
# runs the try/finally cleanup block (SQLFreeStmt + diag drain +
# cursor bookkeeping reset). If close() raises and the generator is
# still alive (e.g. another thread is currently executing it), keep
# the reference so a subsequent close() / __del__ can retry; only
# drop refs once the generator is actually dead.
gen = self._generator
if gen is not None:
try:
gen.close()
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader.close: generator.close raised: %s", e)
if getattr(gen, "gi_frame", None) is not None:
# Generator still alive — leave _generator (and _cursor,
# so the next retry can re-issue SQLCancel) intact.
return
self._generator = None

# Drop strong refs so the wrapper does not extend the lifetime of
# the parent Cursor or the inner pyarrow reader.
self._cursor = None
self._inner = None


class Cursor: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""
Represents a database cursor, which is used to manage the context of a fetch operation.
Expand Down Expand Up @@ -2705,15 +2873,28 @@ def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":

def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
"""
Fetch the result as a pyarrow RecordBatchReader, which yields Record
Batches of the specified size until the current result set is
exhausted.
Fetch the result as a pyarrow-compatible RecordBatchReader, which
yields Record Batches of the specified size until the current result
set is exhausted.

The returned object behaves like ``pyarrow.RecordBatchReader``
(``schema``, ``read_next_batch``, iteration, context manager) but
its ``close()`` is fully effective. Cleanup is driven
by a ``try/finally`` block inside the underlying batch generator, so
the same teardown — ``SQLCancel`` to unblock any in-flight fetch on
another thread, ``SQLFreeStmt(SQL_CLOSE)`` to release the server-side
cursor and locks, draining diagnostics into ``cursor.messages``, and
resetting the parent ``Cursor``'s rownumber / ``rowcount`` state —
runs whether the user (a) exhausts the reader normally, (b) calls
``close()`` mid-iteration, (c) exits a ``with`` block, or (d) just
lets the reader be garbage-collected. The parent ``Cursor`` itself
is **not** closed and can be re-executed. ``close()`` is idempotent.

Args:
batch_size: Size of the Record Batches produced by the reader.

Returns:
A pyarrow RecordBatchReader for the result set.
A pyarrow-compatible RecordBatchReader for the result set.
"""
self._check_closed() # Check if the cursor is closed
pyarrow = self._ensure_pyarrow()
Expand All @@ -2722,11 +2903,61 @@ def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
schema_batch = self.arrow_batch(0)
schema = schema_batch.schema

# Capture the parent cursor in a closure cell that the generator
# can null out after cleanup, so a GC'd reader does not keep the
# cursor pinned.
cursor_ref = [self]

def batch_generator():
while (batch := self.arrow_batch(batch_size)).num_rows > 0:
yield batch
try:
while (batch := cursor_ref[0].arrow_batch(batch_size)).num_rows > 0:
yield batch
finally:
# Symmetric server-side teardown — runs on exhaustion,
# GeneratorExit (from close()), or an exception inside the
# body. This is the single canonical cleanup site.
cur = cursor_ref[0]
cursor_ref[0] = None
if cur is None or cur.closed or cur.hstmt is None:
return

return pyarrow.RecordBatchReader.from_batches(schema, batch_generator())
# 1) Drain diagnostics produced by the (possibly cancelled)
# fetch *before* SQL_CLOSE so we don't lose them.
try:
cur.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(cur.hstmt))
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader cleanup: pre-close diag drain failed: %s", e)

# 2) Release the server-side cursor & locks while keeping the
# HSTMT and prepared plan intact, so the parent Cursor can
# be re-executed.
try:
cur.hstmt._close_cursor() # pylint: disable=protected-access
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader cleanup: _close_cursor failed: %s", e)

# 3) Drain diagnostics produced by SQL_CLOSE itself. This
# runs unconditionally because SQL_CLOSE can return
# SQL_SUCCESS_WITH_INFO (a *success* code) and still leave
# warning records on the HSTMT diag stack; the previous
# "only on failure" path would silently drop those.
try:
cur.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(cur.hstmt))
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader cleanup: post-close diag drain failed: %s", e)

# 4) Reset cursor bookkeeping to a clean "no result set"
# state. rowcount becomes -1 to signal that the prior
# result is no longer meaningful.
try:
cur._clear_rownumber() # pylint: disable=protected-access
cur.rowcount = -1
except Exception as e: # pylint: disable=broad-exception-caught
logger.debug("arrow_reader cleanup: bookkeeping reset failed: %s", e)

gen = batch_generator()
inner = pyarrow.RecordBatchReader.from_batches(schema, gen)
return _ArrowReader(self, inner, gen, pyarrow.ArrowInvalid)

def nextset(self) -> Optional[bool]:
"""
Expand Down
32 changes: 31 additions & 1 deletion mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ SQLEndTranFunc SQLEndTran_ptr = nullptr;
SQLFreeHandleFunc SQLFreeHandle_ptr = nullptr;
SQLDisconnectFunc SQLDisconnect_ptr = nullptr;
SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;
SQLCancelFunc SQLCancel_ptr = nullptr;

// Diagnostic APIs
SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr;
Expand Down Expand Up @@ -1295,6 +1296,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLDisconnect_ptr = GetFunctionPointer<SQLDisconnectFunc>(handle, "SQLDisconnect");
SQLFreeHandle_ptr = GetFunctionPointer<SQLFreeHandleFunc>(handle, "SQLFreeHandle");
SQLFreeStmt_ptr = GetFunctionPointer<SQLFreeStmtFunc>(handle, "SQLFreeStmt");
SQLCancel_ptr = GetFunctionPointer<SQLCancelFunc>(handle, "SQLCancel");

SQLGetDiagRec_ptr = GetFunctionPointer<SQLGetDiagRecFunc>(handle, "SQLGetDiagRecW");

Expand Down Expand Up @@ -1433,6 +1435,31 @@ void SqlHandle::close_cursor() {
}
}

void SqlHandle::cancel() {
// SQLCancel is intentionally lenient: it is a no-op on non-STMT handles,
// already-freed handles, or if the driver does not expose it. This lets
// _ArrowReader.close() call it unconditionally without coordinating with
// the fetch thread. The GIL is released so a blocked fetch thread can
// observe the cancel and return.
if (_type != SQL_HANDLE_STMT || !_handle || _implicitly_freed) {
return;
}
if (!SQLCancel_ptr) {
return;
}
SQLHANDLE h = _handle;
SQLRETURN ret;
{
py::gil_scoped_release release;
ret = SQLCancel_ptr(h);
}
// SQLCancel may return SQL_SUCCESS_WITH_INFO when there was nothing to
// cancel; that is fine. We only throw on hard failure.
if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) {
ThrowStdException("SQLCancel failed");
}
}

SQLRETURN SQLResetStmt_wrap(SqlHandlePtr statementHandle) {
if (!statementHandle || !statementHandle->get()) {
return SQL_INVALID_HANDLE;
Expand Down Expand Up @@ -5833,7 +5860,10 @@ PYBIND11_MODULE(ddbc_bindings, m) {
py::class_<SqlHandle, SqlHandlePtr>(m, "SqlHandle")
.def("free", &SqlHandle::free, "Free the handle")
.def("_close_cursor", &SqlHandle::close_cursor,
"Internal: close the cursor without freeing the prepared statement");
"Internal: close the cursor without freeing the prepared statement")
.def("_cancel", &SqlHandle::cancel,
"Internal: cancel an in-progress statement (SQLCancel). "
"Safe to call from another thread; no-op if unsupported or idle.");

py::class_<ConnectionHandle>(m, "Connection")
.def(py::init<const std::u16string&, bool, const py::dict&>(), py::arg("conn_str"),
Expand Down
14 changes: 14 additions & 0 deletions mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ typedef SQLRETURN(SQL_API* SQLFreeHandleFunc)(SQLSMALLINT, SQLHANDLE);
typedef SQLRETURN(SQL_API* SQLDisconnectFunc)(SQLHDBC);
typedef SQLRETURN(SQL_API* SQLFreeStmtFunc)(SQLHSTMT, SQLUSMALLINT);

// Cancel API (GH: arrow_reader.close): SQLCancel is one of the two ODBC
// functions guaranteed safe to call from a thread other than the one running
// SQLFetch/SQLExecute, so it is used by _ArrowReader.close() to unblock
// in-flight fetches before SQLFreeStmt(SQL_CLOSE).
typedef SQLRETURN(SQL_API* SQLCancelFunc)(SQLHSTMT);

// Diagnostic APIs
typedef SQLRETURN(SQL_API* SQLGetDiagRecFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT, SQLWCHAR*,
SQLINTEGER*, SQLWCHAR*, SQLSMALLINT, SQLSMALLINT*);
Expand Down Expand Up @@ -171,6 +177,7 @@ extern SQLEndTranFunc SQLEndTran_ptr;
extern SQLFreeHandleFunc SQLFreeHandle_ptr;
extern SQLDisconnectFunc SQLDisconnect_ptr;
extern SQLFreeStmtFunc SQLFreeStmt_ptr;
extern SQLCancelFunc SQLCancel_ptr;

// Diagnostic APIs
extern SQLGetDiagRecFunc SQLGetDiagRec_ptr;
Expand Down Expand Up @@ -257,6 +264,13 @@ class SqlHandle {
SQLSMALLINT type() const;
void free();
void close_cursor();
// Cancel an in-progress statement (SQLCancel). Safe to call from a
// thread other than the one running the fetch — this is the *only*
// ODBC entry point (along with SQLGetDiagField/Rec) for which the spec
// guarantees cross-thread safety. Releases the GIL while calling.
// No-op for non-STMT handles, freed handles, or when the function is
// unavailable.
void cancel();
bool isImplicitlyFreed() const { return _implicitly_freed; }

// Mark this handle as implicitly freed (freed by parent handle)
Expand Down
Loading
Loading