Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2941,6 +2941,15 @@ def bulkcopy(
# Translate parsed connection string into the dict py-core expects.
pycore_context = connstr_to_pycore_params(params)

# Forward the cursor's query timeout to py-core so the bulkcopy
# connection uses the same limit instead of py-core's compiled-in 15s
# default. _timeout is the snapshot taken at cursor creation (same value
# _set_timeout uses); 0 means "no override", so py-core keeps its default.
# type-is-int guard keeps bool/mocked values from leaking through.
connect_timeout = self._timeout
if type(connect_timeout) is int and connect_timeout > 0:
pycore_context["connect_timeout"] = connect_timeout

# Token acquisition — only thing cursor must handle (needs azure-identity SDK)
if self.connection._auth_type:
# Fresh token acquisition for mssql-py-core connection
Expand Down
65 changes: 65 additions & 0 deletions tests/test_020_bulkcopy_auth_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

cursor = Cursor.__new__(Cursor)
cursor._connection = mock_conn
cursor._timeout = 0
cursor.closed = False
cursor.hstmt = None
return cursor
Expand Down Expand Up @@ -108,3 +109,67 @@
assert "access_token" not in captured_context
assert captured_context.get("user_name") == "sa"
assert captured_context.get("password") == "mypwd"


def _capture_bulkcopy_context(cursor):
"""Run bulkcopy with a mocked pycore module and return the captured context."""
captured_context = {}

mock_pycore_cursor = MagicMock()
mock_pycore_cursor.bulkcopy.return_value = {
"rows_copied": 1,
"batch_count": 1,
"elapsed_time": 0.1,
}
mock_pycore_conn = MagicMock()
mock_pycore_conn.cursor.return_value = mock_pycore_cursor

def capture_context(ctx, **kwargs):
captured_context.update(ctx)
return mock_pycore_conn

mock_pycore_module = MagicMock()
mock_pycore_module.PyCoreConnection = capture_context

with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10)

return captured_context


class TestBulkcopyConnectTimeout:
"""Verify cursor.bulkcopy forwards the cursor timeout to pycore (issue #626)."""

@patch("mssql_python.cursor.logger")
def test_positive_timeout_forwarded(self, mock_logger):
"""cursor._timeout > 0 ⇒ connect_timeout reaches pycore, overriding 15s."""
mock_logger.is_debug_enabled = False
cursor = _make_cursor("Server=localhost;Database=testdb;UID=sa;PWD=pwd", None)

Check notice

Code scanning / devskim

Accessing localhost could indicate debug code, or could hinder scaling. Note test

Do not leave debug code in production
cursor._timeout = 30

captured = _capture_bulkcopy_context(cursor)

assert captured.get("connect_timeout") == 30

@patch("mssql_python.cursor.logger")
def test_zero_timeout_not_forwarded(self, mock_logger):
"""cursor._timeout == 0 ⇒ no override, pycore keeps its default."""
mock_logger.is_debug_enabled = False
cursor = _make_cursor("Server=localhost;Database=testdb;UID=sa;PWD=pwd", None)

Check notice

Code scanning / devskim

Accessing localhost could indicate debug code, or could hinder scaling. Note test

Do not leave debug code in production
cursor._timeout = 0

captured = _capture_bulkcopy_context(cursor)

assert "connect_timeout" not in captured

@patch("mssql_python.cursor.logger")
def test_uses_cursor_snapshot_not_live_connection(self, mock_logger):
"""timeout is the cursor snapshot; later connection changes don't apply."""
mock_logger.is_debug_enabled = False
cursor = _make_cursor("Server=localhost;Database=testdb;UID=sa;PWD=pwd", None)

Check notice

Code scanning / devskim

Accessing localhost could indicate debug code, or could hinder scaling. Note test

Do not leave debug code in production
cursor._timeout = 45
cursor._connection.timeout = 99 # changed after cursor creation, must be ignored

captured = _capture_bulkcopy_context(cursor)

assert captured.get("connect_timeout") == 45
Loading