diff --git a/.coveragerc b/.coveragerc index 1182c652..9b922851 100644 --- a/.coveragerc +++ b/.coveragerc @@ -20,6 +20,9 @@ exclude_lines = # Don't complain if non-runnable code isn't run if __name__ == .__main__.: + + # Type-checking-only imports never execute at runtime + if TYPE_CHECKING: # Exclude all logging statements (zero overhead when disabled by design) logger\.debug diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fa44f85..7bdab991 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- New `token_provider=` parameter on `connect()` / `Connection` for Microsoft + Entra ID authentication with a custom credential object. Accepts any object + exposing a `.get_token(scope)` method (e.g. any `azure-identity` credential + such as `DefaultAzureCredential`, `AzureCliCredential`, + `ManagedIdentityCredential`). Mutually exclusive with `Authentication=` in + the connection string and with `attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`. + Bulk copy re-acquires a fresh token from the provider on each operation. The + token scope is fixed to the Azure commercial cloud; sovereign clouds are out + of scope (supply a pre-acquired token via `attrs_before` instead). - Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal` via an `entra_id_token_factory` callback registered on the mssql-py-core connection. 
The callback is invoked by mssql-tds mid-handshake (FedAuth diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 4abd2fb9..d8e1359e 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -5,21 +5,27 @@ """ import hashlib +import inspect import platform import struct import threading -from typing import Tuple, Dict, Optional +import time +import warnings +from typing import Tuple, Dict, Optional, Any, TYPE_CHECKING + +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential from mssql_python.logging import logger from mssql_python.constants import ( AuthType, - ConstantsDDBC, _AuthInternal, _KEY_AUTHENTICATION, _KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, ) +from mssql_python.exceptions import InterfaceError, OperationalError # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in @@ -34,6 +40,13 @@ # Canonical keys to strip when handing an Entra-token connection to ODBC. _SENSITIVE_KEYS = frozenset({_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, _KEY_AUTHENTICATION}) +# Azure SQL Database OAuth scope for the Azure **commercial** cloud. Shared by +# the built-in AADAuth path and the custom token_provider path. Sovereign +# clouds (Azure US Gov, Azure China, Azure Germany) are out of scope — a token +# for a different audience is rejected by SQL Server at login. +_DATABASE_SCOPE = "https://database.windows.net/.default" + + # Map Authentication connection-string values to internal short names. 
_AUTH_TYPE_MAP: Dict[str, str] = { AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, @@ -147,7 +160,7 @@ def _acquire_token( auth_type, ) credential = _credential_cache[cache_key] - raw_token = credential.get_token("https://database.windows.net/.default").token + raw_token = credential.get_token(_DATABASE_SCOPE).token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", len(raw_token), @@ -437,3 +450,168 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: """ auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower() return _AUTH_TYPE_MAP.get(auth_value) + + +def _get_token_from_credential(credential: "TokenCredential") -> Tuple[str, Optional[int]]: + """Internal: call credential.get_token() and return ``(raw_jwt, expires_on)``. + + Centralises the token-acquisition + error-wrapping logic that both + :func:`acquire_token_from_credential` and + :func:`acquire_raw_token_from_credential` need. + + ``expires_on`` is the POSIX timestamp (seconds) at which the token + expires, taken from the credential's ``AccessToken`` result when present + (it is ``None`` if the provider does not supply one). It is captured so + callers can log it and reason about token lifetime; the access token + itself is a *pre-connect* ODBC attribute and cannot be refreshed on a + live connection (see the module docs on token lifecycle). + + Note: + The scope is hard-coded to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds + (Azure US Government, Azure China, Azure Germany) are **out of + scope** for the ``token_provider`` path — for those, supply a + pre-acquired token via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``. + + Raises: + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. 
+ """ + start_time = time.perf_counter() + try: + token_result = credential.get_token(_DATABASE_SCOPE) + except TypeError as e: + # get_token() is called with exactly one positional scope argument, so + # a TypeError here almost always means its signature can't accept a + # scope (e.g. a zero-arg or keyword-only get_token). Surface that as a + # clear, actionable InterfaceError instead of an opaque failure. This + # is the call-time source of truth for arity — the connect() path does + # not inspect the signature, so it never blocks a credential + # whose signature is merely hard to introspect (partial/decorated). + raise InterfaceError( + driver_error=( + "token_provider.get_token() must accept a scope positional " + "argument, e.g. get_token(scope)." + ), + ddbc_error=str(e), + ) from e + except Exception as e: + logger.error( + "_get_token_from_credential: get_token() failed - credential=%s, error=%s", + type(credential).__name__, + str(e), + ) + # Preserve the original credential exception (e.g. azure-identity + # ClientAuthenticationError) as __cause__ for programmatic handling. + raise OperationalError( + driver_error=(f"Failed to acquire token from credential ({type(credential).__name__})"), + ddbc_error=str(e), + ) from e + + # azure.identity.aio (async) credentials return a coroutine from a + # synchronous get_token() call. Detect it and fail with an async-specific + # message rather than tripping over a missing .token attribute — and close + # the coroutine so it doesn't emit a "coroutine was never awaited" warning. + if inspect.iscoroutine(token_result): + token_result.close() + raise InterfaceError( + driver_error=( + "token_provider.get_token() returned a coroutine, which indicates " + "an async credential (e.g. from azure.identity.aio). Use a " + "synchronous credential instead."
+ ), + ddbc_error=f"got coroutine from {type(credential).__name__}.get_token()", + ) + + raw_token = getattr(token_result, "token", None) + if not isinstance(raw_token, str) or not raw_token: + raise InterfaceError( + driver_error=( + "token_provider.get_token() must return an object with a non-empty " + "string '.token' attribute." + ), + ddbc_error=f"got .token of type {type(raw_token).__name__}", + ) + + expires_on = getattr(token_result, "expires_on", None) + # Warn (don't fail) if the credential handed back an already-expired token: + # the server enforces expiry and will reject the login, so surfacing it here + # points at the real cause instead of an opaque later failure. Only numeric + # POSIX timestamps are checked; bools are excluded to avoid false positives. + if ( + isinstance(expires_on, (int, float)) + and not isinstance(expires_on, bool) + and expires_on < time.time() + ): + warnings.warn( + f"token_provider returned a token that is already expired " + f"(expires_on={expires_on} is in the past). The server will likely " + f"reject the connection.", + UserWarning, + stacklevel=2, + ) + elapsed_ms = (time.perf_counter() - start_time) * 1000 + logger.info( + "_get_token_from_credential: Token acquired from %s - length=%d chars, " + "expires_on=%s, duration_ms=%.2f", + type(credential).__name__, + len(raw_token), + expires_on, + elapsed_ms, + ) + return raw_token, expires_on + + +def acquire_token_from_credential(credential: "TokenCredential") -> Tuple[bytes, Optional[int]]: + """Acquire an ODBC token struct from a user-supplied credential object. + + The credential must follow the Azure ``TokenCredential`` protocol — i.e. + have a ``.get_token(scope)`` method returning an object with a ``.token`` + attribute (a raw JWT string). + + .. note:: + The scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). 
Sovereign clouds (Azure + US Government, Azure China, Azure Germany) are **out of scope** — for + those, supply a pre-acquired token via + ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + Tuple[bytes, Optional[int]]: The ODBC token struct for + ``SQL_COPT_SS_ACCESS_TOKEN`` and the token's ``expires_on`` POSIX + timestamp (``None`` if the provider does not supply one). + + Raises: + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. + """ + raw_token, expires_on = _get_token_from_credential(credential) + return AADAuth.get_token_struct(raw_token), expires_on + + +def acquire_raw_token_from_credential(credential: "TokenCredential") -> Tuple[str, Optional[int]]: + """Acquire a raw JWT string from a user-supplied credential object. + + Used by bulk copy, which needs the raw JWT rather than the ODBC struct. + + .. note:: + The scope is fixed to the Azure **commercial** cloud. Sovereign + clouds are **out of scope** — see + :func:`acquire_token_from_credential`. + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + Tuple[str, Optional[int]]: The raw JWT token string and the token's + ``expires_on`` POSIX timestamp (``None`` if the provider does not + supply one). + + Raises: + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. 
+ """ + return _get_token_from_credential(credential) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 94fb0924..b02943f0 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -14,6 +14,7 @@ import weakref import re import codecs +import warnings from typing import Any, Dict, Optional, Union, List, Tuple, Callable, TYPE_CHECKING import threading @@ -53,11 +54,14 @@ _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID, + _KEY_PWD, + _KEY_TRUSTED_CONNECTION, _AuthInternal, ) if TYPE_CHECKING: from mssql_python.row import Row + from azure.core.credentials import TokenCredential # Add SQL_WMETADATA constant for metadata decoding configuration SQL_WMETADATA: int = -99 # Special flag for column name decoding @@ -139,10 +143,10 @@ def _validate_utf16_wchar_compatibility( # Generate context-appropriate error messages if "ctype" in context: - driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings" + driver_error = "SQL_WCHAR ctype only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR ctype" else: - driver_error = f"SQL_WCHAR only supports UTF-16 encodings" + driver_error = "SQL_WCHAR only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR" raise ProgrammingError( @@ -251,6 +255,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + token_provider: Optional["TokenCredential"] = None, **kwargs: Any, ) -> None: """ @@ -270,12 +275,53 @@ def __init__( native_uuid (bool, optional): Controls whether UNIQUEIDENTIFIER columns return uuid.UUID objects (True) or str (False) for cursors created from this connection. None (default) defers to the module-level ``mssql_python.native_uuid`` setting (True). + token_provider (object, optional): Advanced token provider for Microsoft Entra ID + authentication. Must expose a callable ``.get_token(scope)`` method that returns + an object with a ``.token`` attribute. 
+ + This parameter is mutually exclusive with ``Authentication=`` in the connection + string and with ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``; supplying more than + one token source raises ``InterfaceError`` at connect time. + + If ``UID``/``PWD``/``Trusted_Connection`` are also present in the connection + string they are ignored (access-token auth wins) and a warning is emitted. + + .. note:: + The token scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds (Azure US + Government, Azure China, Azure Germany) are **out of scope** for this + parameter — a token acquired for a different audience is rejected by SQL + Server at login. For sovereign clouds, acquire the token yourself and pass + it via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + + .. note:: + Connection pooling is automatically disabled for any access-token + connection (``token_provider=``, built-in ``Authentication=ActiveDirectory*``, + or a raw ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``). The native pool keys + on the sanitized connection string only, so different principals sharing the + same server/database would otherwise collide in one pool bucket and could be + handed each other's authenticated connection. Disabling pooling keeps each + principal isolated. + + .. note:: + Token lifecycle limitations: the access token is a *pre-connect* ODBC + attribute, so it cannot be refreshed on a live connection. Long-lived + connections must be recycled by the application once the token nears expiry, + and Continuous Access Evaluation (CAE) claims challenges are not handled. + These require native driver support and are tracked as follow-up work. + Interactive credentials (e.g. ``InteractiveBrowserCredential``) block + ``connect()`` until the user completes sign-in; prefer non-interactive + credentials in server contexts. **kwargs: Additional key/value pairs for the connection string. 
Returns: None Raises: + InterfaceError: If ``token_provider`` is misused (combined with another token + source, or lacking a valid ``.get_token`` method), or the credential returns + no valid token. + OperationalError: If acquiring a token from ``token_provider`` fails. ValueError: If the connection string is invalid or connection fails. This method sets up the initial state for the connection object, @@ -300,7 +346,11 @@ def __init__( self.connection_str, parsed_params = self._construct_connection_string( connection_str, **kwargs ) - self._attrs_before = attrs_before or {} + # Shallow-copy so we never mutate the caller's dict (e.g. when the + # token_provider path injects SQL_COPT_SS_ACCESS_TOKEN). Mutating the + # caller's object would leak the access token into user state and break + # re-using the same attrs_before dict across multiple connections. + self._attrs_before = dict(attrs_before) if attrs_before else {} # Initialize encoding settings with defaults for Python 3 # Python 3 only has str (which is Unicode), so we use utf-16le by default @@ -339,10 +389,23 @@ def __init__( # fresh token; re-parsing self.connection_str at that point would miss # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None + # User-supplied token provider for custom Entra ID authentication. + # Stored so bulk copy can call .get_token() for a fresh JWT later. + self._token_provider: Optional["TokenCredential"] = None + # POSIX timestamp (seconds) at which the current access token expires, + # captured from the credential's AccessToken result. None when unknown. + # The token is a pre-connect ODBC attribute and cannot be refreshed on + # a live connection — this is exposed for diagnostics/logging only. + self._token_expires_on: Optional[int] = None + + # Custom token_provider= parameter — takes priority, mutually exclusive + # with Authentication= in the connection string. 
+ if token_provider is not None: + self._configure_token_provider(token_provider, parsed_params) # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. - if _KEY_AUTHENTICATION in parsed_params: + elif _KEY_AUTHENTICATION in parsed_params: auth_type = process_auth_parameters(parsed_params) if auth_type: @@ -401,6 +464,26 @@ def __init__( if not PoolingManager.is_initialized(): PoolingManager.enable() self._pooling = PoolingManager.is_enabled() + + # Access-token connections must NOT be pooled. The native pool is keyed + # on the (sanitized) connection string only, and the access token lives + # in attrs_before — which is applied solely when a *new* physical + # connection is created and is never re-applied when a pooled connection + # is reused. With pooling on, two different principals that share the + # same Server/Database collapse into the same pool bucket, so one caller + # can be handed another caller's already-authenticated connection + # (silent identity confusion / privilege escalation). This affects every + # access-token path: a raw SQL_COPT_SS_ACCESS_TOKEN supplied directly in + # attrs_before, built-in Authentication=ActiveDirectory* auth, and the + # token_provider= credential — they all funnel the token through + # attrs_before. Disabling pooling for these connections keeps each + # principal isolated. The same-principal reuse case loses pooling, which + # is an acceptable, correct default. Refreshing the token on a live + # connection (so pooling could be re-enabled safely) needs native driver + # support and is tracked as follow-up work. 
+ if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + self._pooling = False + try: self._conn = ddbc_bindings.Connection( self.connection_str, self._pooling, self._attrs_before @@ -425,6 +508,80 @@ def __init__( f"Unexpected error during connection registration: {type(e).__name__}: {e}" ) + def _configure_token_provider( + self, token_provider: "TokenCredential", parsed_params: Dict[str, str] + ) -> None: + """Validate a custom ``token_provider`` and apply its access token. + + Acquires a token from ``token_provider.get_token()`` and injects it as + the ``SQL_COPT_SS_ACCESS_TOKEN`` pre-connect attribute, then strips any + sensitive params from the connection string. Mutually exclusive with + ``Authentication=`` and a manual ``attrs_before`` access token. + + Raises: + InterfaceError: If ``token_provider`` is combined with another token + source, or lacks a ``get_token(scope)`` method. + OperationalError: If acquiring a token from ``token_provider`` fails. + """ + if _KEY_AUTHENTICATION in parsed_params: + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "'Authentication' in the connection string. " + "Use one or the other." + ), + ddbc_error="", + ) + if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "attrs_before[SQL_COPT_SS_ACCESS_TOKEN]. " + "Use one token source." + ), + ddbc_error="", + ) + get_token = getattr(token_provider, "get_token", None) + if not callable(get_token): + raise InterfaceError( + driver_error=( + f"token_provider must have a .get_token() method. " + f"Got {type(token_provider).__name__}." + ), + ddbc_error="", + ) + # The get_token() signature is NOT inspected here: inspect.signature() + # is unreliable for partial/decorated/C-extension callables and would + # produce false warnings on valid credentials. 
The actual call is the + # source of truth — _get_token_from_credential turns a bad signature + # (TypeError) into a clear InterfaceError. + from mssql_python.auth import acquire_token_from_credential + + # Access-token auth ignores UID/PWD/Trusted_Connection — warn so the + # user is not surprised that those credentials are silently dropped. + dropped = [ + key for key in (_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION) if key in parsed_params + ] + if dropped: + warnings.warn( + "token_provider is set, so the following connection-string " + f"credential(s) are ignored: {', '.join(sorted(dropped))}. " + "Remove them to silence this warning.", + UserWarning, + # stacklevel=3 skips _configure_token_provider and __init__, so the + # warning is attributed to whatever called Connection(...) — user code + # directly, or the connect() wrapper — not this internal helper. + stacklevel=3, + ) + token, token_expires_on = acquire_token_from_credential(token_provider) + self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token + self._token_provider = token_provider + self._token_expires_on = token_expires_on + # Strip sensitive params (UID/PWD/Trusted_Connection) since + # access-token auth is used — same as the Authentication= path.
+ sanitized = remove_sensitive_params(parsed_params) + self.connection_str = _ConnectionStringBuilder(sanitized).build() + def _construct_connection_string( self, connection_str: str = "", **kwargs: Any ) -> Tuple[str, Dict[str, str]]: diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 401e434a..01b8d413 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -483,6 +483,16 @@ def get_attribute_set_timing(attribute): _CONNECTION_STRING_DRIVER_KEY = "Driver" _CONNECTION_STRING_APP_KEY = "APP" +_CONNECTION_STRING_AUTH_KEY = "Authentication" +_CONNECTION_STRING_UID_KEY = "UID" +_CONNECTION_STRING_PWD_KEY = "PWD" +_CONNECTION_STRING_TRUSTED_CONNECTION_KEY = "Trusted_Connection" + +# Aliases used by auth.py / connection.py — kept for readability. +_KEY_AUTHENTICATION = _CONNECTION_STRING_AUTH_KEY +_KEY_UID = _CONNECTION_STRING_UID_KEY +_KEY_PWD = _CONNECTION_STRING_PWD_KEY +_KEY_TRUSTED_CONNECTION = _CONNECTION_STRING_TRUSTED_CONNECTION_KEY # Reserved connection string parameters that are controlled by the driver # and cannot be set by users @@ -502,16 +512,16 @@ def get_attribute_set_timing(attribute): "address": "Server", "addr": "Server", # Authentication - "uid": "UID", - "pwd": "PWD", - "authentication": "Authentication", - "trusted_connection": "Trusted_Connection", + "uid": _CONNECTION_STRING_UID_KEY, + "pwd": _CONNECTION_STRING_PWD_KEY, + "authentication": _CONNECTION_STRING_AUTH_KEY, + "trusted_connection": _CONNECTION_STRING_TRUSTED_CONNECTION_KEY, # Database "database": "Database", # Driver (always controlled by mssql-python) - "driver": "Driver", + "driver": _CONNECTION_STRING_DRIVER_KEY, # Application name (always controlled by mssql-python) - "app": "APP", + "app": _CONNECTION_STRING_APP_KEY, # Encryption and Security "encrypt": "Encrypt", "trustservercertificate": "TrustServerCertificate", @@ -535,14 +545,6 @@ def get_attribute_set_timing(attribute): "packetsize": "PacketSize", } -# Canonical normalized key names 
produced by _ConnectionStringParser._normalize_params. -# Consumer code should reference these instead of hard-coding raw strings so that -# a rename in _ALLOWED_CONNECTION_STRING_PARAMS is caught at import time. -_KEY_AUTHENTICATION = "Authentication" -_KEY_UID = "UID" -_KEY_PWD = "PWD" -_KEY_TRUSTED_CONNECTION = "Trusted_Connection" - def get_info_constants() -> Dict[str, int]: """ diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index aa0eed00..b9e9282c 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2942,7 +2942,27 @@ def bulkcopy( pycore_context = connstr_to_pycore_params(params) # Token acquisition — only thing cursor must handle (needs azure-identity SDK) - if self.connection._auth_type: + if self.connection._token_provider is not None: + # User-supplied credential — use it directly for a fresh token. + from mssql_python.auth import acquire_raw_token_from_credential + + try: + raw_token, _ = acquire_raw_token_from_credential(self.connection._token_provider) + except (OperationalError, InterfaceError) as e: + raise OperationalError( + driver_error=( + "Bulk copy failed: unable to acquire token from custom credential" + ), + ddbc_error=str(e), + ) from e + pycore_context["access_token"] = raw_token + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh token from custom credential (%s)", + type(self.connection._token_provider).__name__, + ) + elif self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection from mssql_python.auth import AADAuth, ServicePrincipalAuth from mssql_python.constants import _AuthInternal diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index fe10b819..a0955ccd 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -4,10 +4,13 @@ This module provides a way to create a new connection object to interact with the database. 
""" -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TYPE_CHECKING from mssql_python.connection import Connection +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + def connect( connection_str: str = "", @@ -15,6 +18,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + token_provider: Optional["TokenCredential"] = None, **kwargs: Any, ) -> Connection: """ @@ -35,6 +39,34 @@ def connect( This per-connection override is useful for migration from pyodbc: connections that need string UUIDs can pass native_uuid=False, while the default (True) returns native uuid.UUID objects. + token_provider (object, optional): A token provider for Microsoft Entra ID + authentication. This must be any object with a ``.get_token(scope)`` method that + returns an object with a ``.token`` attribute containing a raw JWT string — for + example, any ``azure-identity`` credential class such as + ``DefaultAzureCredential``, ``AzureCliCredential``, ``ManagedIdentityCredential``, + ``CertificateCredential``, etc. + + When provided, the driver calls ``token_provider.get_token()`` to acquire an + access token for SQL Server, bypassing the built-in credential map. + Cannot be combined with ``Authentication=`` in the connection string. + + For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` + in the connection string — ``DefaultAzureCredential`` automatically picks the + right credential per environment (CLI on dev, Managed Identity in prod). + Use ``token_provider=`` only when you need explicit control over token + acquisition (e.g., excluding specific providers, using a credential not in + the built-in map, or passing custom options to the credential constructor). 
+ + Example:: + + from azure.identity import AzureCliCredential + conn = mssql_python.connect("Server=s;Database=d", + token_provider=AzureCliCredential()) + + Note: the token scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds (Azure US + Government, Azure China, Azure Germany) are **out of scope** — acquire the token + yourself and pass it via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. Keyword Args: **kwargs: Additional key/value pairs for the connection string. Below attributes are not implemented in the internal driver: @@ -58,6 +90,7 @@ def connect( attrs_before=attrs_before, timeout=timeout, native_uuid=native_uuid, + token_provider=token_provider, **kwargs, ) return conn diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index ad18756e..4e66809d 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -9,6 +9,8 @@ import datetime import logging import pyarrow +from azure.core.credentials import TokenCredential + # GLOBALS - DB-API 2.0 Required Module Globals # https://www.python.org/dev/peps/pep-0249/#module-interface apilevel: str # "2.0" @@ -248,6 +250,7 @@ class Connection: attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + token_provider: Optional[TokenCredential] = None, **kwargs: Any, ) -> None: ... @@ -291,6 +294,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + token_provider: Optional[TokenCredential] = None, **kwargs: Any, ) -> Connection: ... 
diff --git a/requirements.txt b/requirements.txt index 4cd60771..ba886f90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,9 @@ unittest-xml-reporting psutil pyarrow +# Runtime dependencies needed for tests +azure-identity + # Build dependencies pybind11 setuptools diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index d82ecaea..7df49da6 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -5,9 +5,13 @@ """ import pytest +import collections +import inspect import platform import sys import threading +import warnings +from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, @@ -18,9 +22,13 @@ get_auth_token, extract_auth_type, _credential_cache, - _credential_cache_lock, + acquire_token_from_credential, + acquire_raw_token_from_credential, + _DATABASE_SCOPE, ) +from azure.core.credentials import TokenCredential from mssql_python.constants import AuthType, ConstantsDDBC +from mssql_python.exceptions import InterfaceError, OperationalError import secrets SAMPLE_TOKEN = secrets.token_hex(44) @@ -582,6 +590,7 @@ def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" mock_conn._auth_type = "msi" mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -1032,6 +1041,979 @@ def test_token_output_correct_on_cache_miss_and_hit(self): assert "default" in _credential_cache +# ── Custom token_provider= parameter tests ── + + +class TestAcquireTokenFromCredential: + """Tests for the acquire_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_token_from_credential returns a token struct and expiry.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + token_struct, expires_on = 
acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + assert expires_on == 1893456000 + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_token_from_credential wraps credential errors in OperationalError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): + acquire_token_from_credential(mock_cred) + + def test_missing_token_attribute_raises_interface_error(self): + """Token provider must return an object exposing a non-empty string .token.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = object() + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_non_string_token_raises_interface_error(self): + """Token provider must return a .token value of type str.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=123) + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_scope_is_commercial_cloud(self): + """The scope is hard-coded to the Azure commercial-cloud audience.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + acquire_token_from_credential(mock_cred) + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_missing_expires_on_returns_none(self): + """A token object without .expires_on yields expires_on=None (not an error).""" + + class MinimalToken: + token = SAMPLE_TOKEN # no expires_on attribute + + mock_cred = MagicMock() + mock_cred.get_token.return_value = MinimalToken() + token_struct, expires_on = acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + assert expires_on is None + + def 
test_bytes_token_raises_interface_error(self): + """A bytes .token (not str) is rejected just like other non-str values.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=b"not_a_str_token") + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_whitespace_only_token_is_accepted(self): + """Documents current behavior: a non-empty whitespace token passes the + client-side check (validity is enforced server-side at login).""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=" ", expires_on=None) + token_struct, _ = acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + + def test_credential_exception_preserved_as_cause(self): + """The original credential error is chained as __cause__ for callers + that want to catch the underlying azure-identity exception.""" + + class ClientAuthenticationError(Exception): + """Stand-in for azure.core.exceptions.ClientAuthenticationError.""" + + original = ClientAuthenticationError("AADSTS700016") + mock_cred = MagicMock() + mock_cred.get_token.side_effect = original + with pytest.raises(OperationalError) as exc_info: + acquire_token_from_credential(mock_cred) + assert exc_info.value.__cause__ is original + + def test_get_token_returns_none_raises_interface_error(self): + """A credential whose get_token returns None is rejected clearly.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = None + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_realistic_length_jwt_round_trips(self): + """A realistic ~1.5 KB JWT is encoded into the ODBC token struct without + truncation (length prefix + UTF-16-LE body).""" + big_jwt = "e" + "A" * 1500 + ".sig" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=big_jwt, expires_on=None) + token_struct, _ = acquire_token_from_credential(mock_cred) + # 
struct = 4-byte little-endian length prefix + UTF-16-LE token bytes. + expected_body = big_jwt.encode("utf-16-le") + assert token_struct[:4] == len(expected_body).to_bytes(4, "little") + assert token_struct[4:] == expected_body + + +class TestAcquireRawTokenFromCredential: + """Tests for the acquire_raw_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_raw_token_from_credential returns the raw JWT string and expiry.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + raw_token, expires_on = acquire_raw_token_from_credential(mock_cred) + assert raw_token == SAMPLE_TOKEN + assert expires_on == 1893456000 + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_raw_token_from_credential wraps credential errors in OperationalError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): + acquire_raw_token_from_credential(mock_cred) + + def test_empty_string_token_raises_interface_error(self): + """Empty token values are rejected as invalid provider output.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token="") + with pytest.raises(InterfaceError, match="non-empty"): + acquire_raw_token_from_credential(mock_cred) + + +class TestCustomTokenProviderConnect: + """Tests for the token_provider= parameter on connect().""" + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_happy_path(self, mock_ddbc_conn): + """token_provider= acquires token and sets attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + from mssql_python import connect + + conn = 
connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_provider is mock_cred + assert conn._token_expires_on == 1893456000 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + # Existing auth_type should be None (no Authentication= in conn str) + assert conn._auth_type is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_plus_authentication_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + Authentication= raises InterfaceError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(InterfaceError, match="Cannot specify both"): + connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", + token_provider=mock_cred, + ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + Authentication via kwargs raises InterfaceError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(InterfaceError, match="Cannot specify both"): + connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + Authentication="ActiveDirectoryDefault", + ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_plus_attrs_before_access_token_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + manual attrs_before token is ambiguous and rejected.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + 
mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(InterfaceError, match="SQL_COPT_SS_ACCESS_TOKEN"): + connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + attrs_before={ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: b"existing_token"}, + ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_without_get_token_raises_typeerror(self, mock_ddbc_conn): + """Passing an object without .get_token() raises InterfaceError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + with pytest.raises(InterfaceError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider="not_a_credential") + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_none_uses_existing_flow(self, mock_ddbc_conn): + """token_provider=None (default) uses existing auth flow, no change.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") + assert conn._token_provider is None + assert conn._auth_type == "default" + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_with_non_auth_attrs_before(self, mock_ddbc_conn): + """token_provider= works alongside non-auth attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT + conn = connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + attrs_before={login_timeout_attr: 30}, + ) + assert conn._attrs_before[login_timeout_attr] == 30 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value 
in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_get_token_failure_raises_runtime_error(self, mock_ddbc_conn): + """If token_provider.get_token() fails, connect() raises OperationalError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("token acquisition failed") + from mssql_python import connect + + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): + connect("Server=test;Database=testdb", token_provider=mock_cred) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): + """Object with .get_token as a non-callable attribute raises InterfaceError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class BadCredential: + get_token = "not_a_method" + + with pytest.raises(InterfaceError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider=BadCredential()) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_multiple_connections_share_same_token_provider(self, mock_ddbc_conn): + """Two connections can share the same token provider object safely.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn1 = connect("Server=test1;Database=db1", token_provider=mock_cred) + conn2 = connect("Server=test2;Database=db2", token_provider=mock_cred) + assert conn1._token_provider is conn2._token_provider + assert mock_cred.get_token.call_count == 2 + conn1.close() + conn2.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_concurrent_connections_with_same_token_provider(self, mock_ddbc_conn): + """Concurrent connect() calls with one token 
provider should succeed.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + def _open_and_close(i): + conn = connect(f"Server=test{i};Database=testdb", token_provider=mock_cred) + conn.close() + + with ThreadPoolExecutor(max_workers=8) as executor: + list(executor.map(_open_and_close, range(20))) + + assert mock_cred.get_token.call_count == 20 + + +class TestTokenProviderValidation: + """Tests for token_provider get_token arity validation and the dropped-credential warning.""" + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_scope_is_commercial_cloud(self, mock_ddbc_conn): + """connect() requests the fixed commercial-cloud scope from the credential.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_get_token_wrong_arity_raises_interface_error(self, mock_ddbc_conn): + """A get_token() that cannot accept a scope argument is rejected up-front.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class ZeroArgCredential: + def get_token(self): # missing scope parameter + return MagicMock(token=SAMPLE_TOKEN) + + # No up-front signature inspection: the call-time validation raises. 
+ with pytest.raises(InterfaceError, match="must accept a scope"): + connect("Server=test;Database=testdb", token_provider=ZeroArgCredential()) + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_get_token_with_scope_param_accepted(self, mock_ddbc_conn): + """A well-formed get_token(scope) passes arity validation.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class GoodCredential: + def get_token(self, scope): + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + conn = connect("Server=test;Database=testdb", token_provider=GoodCredential()) + assert conn._token_expires_on == 1893456000 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_uninspectable_get_token_skips_validation(self, mock_ddbc_conn): + """A get_token whose signature can't be introspected still works (no signature + inspection happens; the real call is the source of truth).""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + from mssql_python import connect + + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_expires_on == 1893456000 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_dropped_uid_pwd_emits_warning(self, mock_ddbc_conn): + """UID/PWD in the connection string trigger a warning when token_provider is set.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.warns(UserWarning, match="credential\\(s\\) are ignored"): + conn = connect( + "Server=test;Database=testdb;UID=user@test.com;PWD=secret", + token_provider=mock_cred, + ) + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def 
test_no_warning_without_dropped_credentials(self, mock_ddbc_conn): + """No 'ignored credentials' warning when the connection string has no UID/PWD.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert not any("are ignored" in str(w.message) for w in caught) + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_real_azure_style_signature_accepted(self, mock_ddbc_conn): + """get_token(self, *scopes, **kwargs) — the real azure-identity shape — + passes arity validation.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class AzureStyleCredential: + def get_token(self, *scopes, **kwargs): + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + conn = connect("Server=test;Database=testdb", token_provider=AzureStyleCredential()) + assert conn._token_expires_on == 1893456000 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_connection_string_sanitized_of_uid_pwd(self, mock_ddbc_conn): + """UID/PWD are stripped from connection_str when token_provider is used.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + conn = connect( + "Server=test;Database=testdb;UID=user@test.com;PWD=secret", + token_provider=mock_cred, + ) + assert "UID=" not in conn.connection_str + assert "PWD=" not in conn.connection_str + assert "secret" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_missing_expires_on_sets_none(self, 
mock_ddbc_conn): + """A credential whose token lacks .expires_on leaves _token_expires_on None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class MinimalToken: + token = SAMPLE_TOKEN # no expires_on + + class MinimalCredential: + def get_token(self, scope): + return MinimalToken() + + conn = connect("Server=test;Database=testdb", token_provider=MinimalCredential()) + assert conn._token_expires_on is None + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_dropped_trusted_connection_emits_warning(self, mock_ddbc_conn): + """Trusted_Connection alone also triggers the dropped-credential warning.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.warns(UserWarning, match="credential\\(s\\) are ignored"): + conn = connect( + "Server=test;Database=testdb;Trusted_Connection=yes", + token_provider=mock_cred, + ) + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_async_credential_coroutine_rejected(self, mock_ddbc_conn): + """An async credential returns a coroutine from a synchronous get_token() + call and is rejected with a clear, async-specific InterfaceError (no + un-awaited-coroutine warning, since the coroutine is closed).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class AsyncCredential: + async def get_token(self, scope): # azure.identity.aio shape + return MagicMock(token=SAMPLE_TOKEN) + + cred = AsyncCredential() + with pytest.raises(InterfaceError, match="async credential"): + connect("Server=test;Database=testdb", token_provider=cred) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_suspicious_signature_warns_but_does_not_block(self, mock_ddbc_conn): + """A credential 
with a hard-to-introspect signature (partial/decorated) is + never rejected or warned at connect time — the real call is the source of + truth, so it just succeeds when the call works.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class WorkingCredential: + def get_token(self, scope): # genuinely accepts a scope + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + cred = WorkingCredential() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + conn = connect("Server=test;Database=testdb", token_provider=cred) + assert not any("does not appear to accept" in str(w.message) for w in caught) + # Not blocked: the connection succeeded and captured the token. + assert conn._token_expires_on == 1893456000 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_keyword_only_scope_rejected(self, mock_ddbc_conn): + """get_token(self, *, scope) can't take scope positionally and is rejected.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class KeywordOnlyCredential: + def get_token(self, *, scope): + return MagicMock(token=SAMPLE_TOKEN) + + # No up-front signature inspection: the call-time validation raises. 
+ with pytest.raises(InterfaceError, match="must accept a scope"): + connect("Server=test;Database=testdb", token_provider=KeywordOnlyCredential()) + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_caller_attrs_before_dict_not_mutated(self, mock_ddbc_conn): + """connect() must not inject the access token into the caller's own + attrs_before dict (it would leak the secret and break dict reuse).""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT + caller_opts = {login_timeout_attr: 30} + conn = connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + attrs_before=caller_opts, + ) + # The caller's dict is untouched: no access token leaked in. + assert caller_opts == {login_timeout_attr: 30} + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in caller_opts + # The connection's own copy did receive the token. + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_reusing_attrs_before_across_connections_succeeds(self, mock_ddbc_conn): + """The same attrs_before dict can be reused for a second connection with + a different provider — proves the dict isn't polluted by the first.""" + mock_ddbc_conn.return_value = MagicMock() + cred_a = MagicMock() + cred_a.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + cred_b = MagicMock() + cred_b.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + shared_opts = {113: 30} # SQL_ATTR_LOGIN_TIMEOUT + c1 = connect("Server=s;Database=d", token_provider=cred_a, attrs_before=shared_opts) + # Without the copy fix this raises "Cannot specify both ... access token". 
+ c2 = connect("Server=s;Database=d", token_provider=cred_b, attrs_before=shared_opts) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c1._attrs_before + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c2._attrs_before + assert c1._attrs_before is not c2._attrs_before + c1.close() + c2.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_expired_expires_on_warns_but_is_accepted(self, mock_ddbc_conn): + """An already-expired expires_on is still accepted (the server enforces + expiry), but a warning is emitted so the likely cause surfaces early.""" + mock_ddbc_conn.return_value = MagicMock() + past = 1 # 1970 — long expired + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=past) + from mssql_python import connect + + with pytest.warns(UserWarning, match="already expired"): + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_expires_on == past + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_value_not_in_exception_message(self, mock_ddbc_conn): + """A provider failure must not leak the acquired token in the error.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + from mssql_python import connect + + with pytest.raises(OperationalError) as exc_info: + connect("Server=test;Database=testdb", token_provider=mock_cred) + assert SAMPLE_TOKEN not in str(exc_info.value) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_value_not_in_logs(self, mock_ddbc_conn, caplog): + """The raw JWT must never be written to logs (only its length).""" + import logging + + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from 
mssql_python import connect + + with caplog.at_level(logging.DEBUG): + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert SAMPLE_TOKEN not in caplog.text + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_empty_connection_string_with_token_provider(self, mock_ddbc_conn): + """An empty connection string with token_provider should not crash the + validation path; the token is still acquired and attached.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("", token_provider=mock_cred) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + +class TestTokenProviderProtocol: + """Tests for the runtime_checkable azure.core TokenCredential protocol.""" + + def test_object_with_get_token_is_instance(self): + """An object exposing get_token satisfies the Protocol at runtime.""" + + class Cred: + def get_token(self, *scopes, **kwargs): + return MagicMock(token=SAMPLE_TOKEN) + + assert isinstance(Cred(), TokenCredential) + + def test_object_without_get_token_is_not_instance(self): + """An object missing get_token does not satisfy the Protocol.""" + + class NotCred: + def something_else(self): + return None + + assert not isinstance(NotCred(), TokenCredential) + + def test_database_scope_is_commercial_cloud_constant(self): + """The shared scope constant points at the Azure commercial-cloud audience.""" + assert _DATABASE_SCOPE == "https://database.windows.net/.default" + + +class TestTokenProviderPooling: + """Pins pooling behavior for access-token connections. + + The native pool keys on the (sanitized) connection string only, and the + access token lives in attrs_before — applied just once when a *new* physical + connection is created and never re-applied on reuse. 
So two different + principals that share the same Server/Database would collide in the same + pool bucket and one could be handed another's authenticated connection. + To prevent that silent identity confusion, Connection.__init__ disables + pooling whenever an access token is present in attrs_before. These tests pin + that contract for every access-token path (raw SQL_COPT_SS_ACCESS_TOKEN, + built-in Authentication=ActiveDirectory*, and token_provider=). + """ + + @staticmethod + def _pooling_arg(mock_ddbc_conn): + """Return the `pooling` positional arg passed to ddbc_bindings.Connection.""" + # ddbc_bindings.Connection(connection_str, pooling, attrs_before) + return mock_ddbc_conn.call_args.args[1] + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_disables_pooling(self, mock_ddbc_conn): + """token_provider= connections must not be pooled (cross-principal + collision guard).""" + mock_ddbc_conn.return_value = MagicMock() + cred = MagicMock() + cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=s;Database=d", token_provider=cred) + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_raw_access_token_in_attrs_before_disables_pooling(self, mock_ddbc_conn): + """A raw SQL_COPT_SS_ACCESS_TOKEN supplied directly in attrs_before (the + pyodbc-style path) must also disable pooling — this path was uncovered + before the fix.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.constants import ConstantsDDBC + + token_struct = b"\x04\x00\x00\x00test" + conn = connect( + "Server=s;Database=d", + attrs_before={ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, + ) + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + 
@patch("mssql_python.connection.get_auth_token") + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_builtin_entra_auth_disables_pooling(self, mock_ddbc_conn, mock_get_token): + """Built-in Authentication=ActiveDirectory* auth that injects a token into + attrs_before (e.g. ActiveDirectoryDefault) must also disable pooling — + this path was uncovered before the fix. (Driver-native paths such as + ServicePrincipal keep credentials in the connection string and remain + poolable; see test_builtin_driver_native_auth_keeps_pooling.)""" + mock_ddbc_conn.return_value = MagicMock() + mock_get_token.return_value = b"\x04\x00\x00\x00test" + from mssql_python import connect + + conn = connect("Server=s;Database=d;Authentication=ActiveDirectoryDefault") + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_builtin_driver_native_auth_keeps_pooling(self, mock_ddbc_conn): + """Driver-native Entra auth (ServicePrincipal) keeps UID/PWD in the + connection string, so the pool key already distinguishes principals and + pooling stays enabled — no token is injected into attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.constants import ConstantsDDBC + from mssql_python.pooling import PoolingManager + + PoolingManager._reset_for_testing() + conn = connect( + "Server=s;Database=d;Authentication=ActiveDirectoryServicePrincipal;" + "UID=app-id;PWD=app-secret" + ) + # No access token was injected, so pooling is left enabled. 
+ assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in conn._attrs_before + assert self._pooling_arg(mock_ddbc_conn) is True + assert conn._pooling is True + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_non_token_connection_keeps_pooling_enabled(self, mock_ddbc_conn): + """A plain connection (no access token) is still eligible for pooling — + the fix must not regress normal SQL/Windows-auth pooling.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.pooling import PoolingManager + + PoolingManager._reset_for_testing() + conn = connect("Server=s;Database=d;UID=sa;PWD=secret") + assert self._pooling_arg(mock_ddbc_conn) is True + assert conn._pooling is True + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_different_providers_yield_identical_connection_string(self, mock_ddbc_conn): + """Two different providers -> same sanitized connection string. This is + exactly why pooling must be disabled: the pool key (the connection + string) can't tell the principals apart.""" + mock_ddbc_conn.return_value = MagicMock() + cred_a = MagicMock() + cred_a.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + cred_b = MagicMock() + cred_b.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + c1 = connect("Server=s;Database=d", token_provider=cred_a) + c2 = connect("Server=s;Database=d", token_provider=cred_b) + assert c1.connection_str == c2.connection_str + # ...but neither is pooled, so the collision can never occur. 
+ assert c1._pooling is False + assert c2._pooling is False + c1.close() + c2.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_not_refreshed_after_connect(self, mock_ddbc_conn): + """The access token is a pre-connect attribute: it is acquired exactly + once at connect() and not re-acquired for the life of the connection.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1) + from mssql_python import connect + + # expires_on=1 is in the past, so the expired-token warning fires; the + # point of this test is that the token is acquired exactly once. + with pytest.warns(UserWarning, match="already expired"): + conn = connect("Server=s;Database=d", token_provider=mock_cred) + # Even though expires_on is in the past, nothing re-acquires the token. + assert mock_cred.get_token.call_count == 1 + conn.close() + assert mock_cred.get_token.call_count == 1 + + +# --- Faithful azure-identity stand-ins ------------------------------------- +# These mirror the real azure.core.credentials API so the token_provider path +# is exercised exactly as it would be with a live `azure-identity` install, +# without taking a dependency on the package or making network calls. + +# azure.core.credentials.AccessToken is a NamedTuple(token: str, expires_on: int). +_AccessToken = collections.namedtuple("AccessToken", ["token", "expires_on"]) + + +class _FakeDefaultAzureCredential: + """Mirrors azure.identity.DefaultAzureCredential. + + Real signature: + get_token(self, *scopes, claims=None, tenant_id=None, + enable_cae=False, **kwargs) -> AccessToken + The SDK caches internally and hands back the same AccessToken until it is + near expiry, so repeated calls are cheap and return a stable token. 
+ """ + + def __init__(self, token=SAMPLE_TOKEN, expires_on=1893456000): + self._cached = _AccessToken(token, expires_on) + self.calls = [] + + def get_token(self, *scopes, claims=None, tenant_id=None, enable_cae=False, **kwargs): + self.calls.append(scopes) + return self._cached + + +class _FakeClientSecretCredential: + """Mirrors azure.identity.ClientSecretCredential (service principal).""" + + def __init__(self, tenant_id, client_id, client_secret, token=SAMPLE_TOKEN): + self.tenant_id = tenant_id + self.client_id = client_id + self._secret = client_secret + self._token = token + self.calls = 0 + + def get_token(self, *scopes, **kwargs): + self.calls += 1 + return _AccessToken(self._token, 1893456000) + + +class _FakeManagedIdentityCredential: + """Mirrors azure.identity.ManagedIdentityCredential (App Service / VM).""" + + def __init__(self, client_id=None, token=SAMPLE_TOKEN): + self.client_id = client_id + self._token = token + + def get_token(self, *scopes, **kwargs): + return _AccessToken(self._token, 1893456000) + + +class _FakeInteractiveBrowserCredential: + """Mirrors azure.identity.InteractiveBrowserCredential. + + The first call performs an interactive sign-in (slow / may block); after + that the token is cached. We model that the first get_token is the one that + "logs in" and subsequent calls return the cached value. + """ + + def __init__(self, token=SAMPLE_TOKEN): + self._token = token + self.login_count = 0 + + def get_token(self, *scopes, claims=None, tenant_id=None, enable_cae=False, **kwargs): + if self.login_count == 0: + self.login_count += 1 # "interactive sign-in" happens here + return _AccessToken(self._token, 1893456000) + + +class TestTokenProviderRealWorld: + """End-to-end checks against faithful azure-identity credential stand-ins. 
class TestTokenProviderRealWorld:
    """Exercise ``token_provider=`` against azure-identity-shaped credential fakes.

    Every test patches ``mssql_python.connection.ddbc_bindings.Connection`` so
    no native driver is touched, calls ``connect(..., token_provider=<fake>)``
    and then inspects the state recorded on the returned connection and on the
    fake credential (AccessToken-style return value, ``*scopes``/``**kwargs``
    ``get_token`` signatures).
    """

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_default_azure_credential_end_to_end(self, mock_ddbc_conn):
        """Happy path for ``connect(conn_str, token_provider=DefaultAzureCredential())``."""
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeDefaultAzureCredential()

        connection = connect(
            "Server=myserver.database.windows.net;Database=mydb",
            token_provider=credential,
        )

        # Exactly one get_token call, made with the commercial-cloud database scope.
        assert credential.calls == [(_DATABASE_SCOPE,)]
        assert connection._token_provider is credential
        assert connection._token_expires_on == 1893456000
        token_attr = ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value
        assert token_attr in connection._attrs_before
        connection.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_client_secret_credential_service_principal(self, mock_ddbc_conn):
        """Service principal usage: ``ClientSecretCredential(tenant, id, secret)``."""
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        secret = "super-secret"
        credential = _FakeClientSecretCredential(
            tenant_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
            client_id="11111111-2222-3333-4444-555555555555",
            client_secret=secret,
        )

        connection = connect(
            "Server=s.database.windows.net;Database=d",
            token_provider=credential,
        )

        assert credential.calls == 1
        assert connection._token_provider is credential
        # The client secret must not show up in the connection string kept on
        # the connection object.
        assert secret not in connection.connection_str
        connection.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_managed_identity_credential_app_service(self, mock_ddbc_conn):
        """App Service / VM usage: ``ManagedIdentityCredential(client_id=...)``."""
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeManagedIdentityCredential(client_id="user-assigned-mi-client-id")

        connection = connect(
            "Server=s.database.windows.net;Database=d",
            token_provider=credential,
        )

        assert connection._token_provider is credential
        token_attr = ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value
        assert token_attr in connection._attrs_before
        connection.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_interactive_browser_credential_signs_in_once(self, mock_ddbc_conn):
        """A single connect() drives the fake through exactly one sign-in."""
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeInteractiveBrowserCredential()

        connection = connect(
            "Server=s.database.windows.net;Database=d",
            token_provider=credential,
        )

        assert credential.login_count == 1
        connection.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_access_token_namedtuple_round_trips(self, mock_ddbc_conn):
        """``.token`` / ``.expires_on`` of the returned AccessToken reach the connection."""
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeDefaultAzureCredential(expires_on=1999999999)

        connection = connect("Server=s;Database=d", token_provider=credential)

        assert connection._token_expires_on == 1999999999
        # The injected attribute is not the raw JWT: it is a 4-byte
        # little-endian length followed by the UTF-16-LE encoded token.
        packed = connection._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value]
        encoded = SAMPLE_TOKEN.encode("UTF-16-LE")
        length_prefix = len(encoded).to_bytes(4, "little")
        assert packed[:4] == length_prefix
        assert packed[4:] == encoded
        connection.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_one_credential_reused_across_a_connection_pool(self, mock_ddbc_conn):
        """One credential object shared by several connect() calls.

        Each connect() asks the credential for a token, and every resulting
        connection owns its own ``_attrs_before`` dict.
        """
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeDefaultAzureCredential()

        connections = []
        for index in range(5):
            connections.append(connect(f"Server=s{index};Database=d", token_provider=credential))

        assert len(credential.calls) == 5
        # Regression guard for the caller-dict-mutation bug: no two
        # connections may alias the same attrs_before dict.
        distinct_dicts = {id(each._attrs_before) for each in connections}
        assert len(distinct_dicts) == 5
        for each in connections:
            each.close()

    @patch("mssql_python.connection.ddbc_bindings.Connection")
    def test_shared_app_config_dict_reused_for_every_connection(self, mock_ddbc_conn):
        """An application-owned options dict survives repeated connect() calls.

        The app builds ONE ``attrs_before`` dict (here just a login timeout)
        and hands it to every connect() together with a credential. Before the
        fix, the first connect() wrote the access token into that shared dict
        and the second connect() then raised "Cannot specify both ...".
        """
        from mssql_python import connect

        mock_ddbc_conn.return_value = MagicMock()
        credential = _FakeDefaultAzureCredential()

        login_timeout_attr = 113  # SQL_ATTR_LOGIN_TIMEOUT
        shared_attrs = {login_timeout_attr: 30}  # built once, reused everywhere

        first = connect(
            "Server=s1;Database=d", token_provider=credential, attrs_before=shared_attrs
        )
        second = connect(
            "Server=s2;Database=d", token_provider=credential, attrs_before=shared_attrs
        )

        token_attr = ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value
        # The caller's dict still holds only the login timeout, no access token.
        assert shared_attrs == {login_timeout_attr: 30}
        assert token_attr not in shared_attrs
        # Each connection carries its own token plus the app's login timeout.
        for connection in (first, second):
            assert connection._attrs_before[login_timeout_attr] == 30
            assert token_attr in connection._attrs_before
        first.close()
        second.close()
+ assert app_attrs == {SQL_ATTR_LOGIN_TIMEOUT: 30} + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in app_attrs + # Both connections got their own token + the app's login timeout. + for c in (c1, c2): + assert c._attrs_before[SQL_ATTR_LOGIN_TIMEOUT] == 30 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c._attrs_before + c1.close() + c2.close() + + class TestParseTenantId: def test_guid_tenant(self): url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 16443834..1f471f1c 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -12,6 +12,10 @@ import secrets from unittest.mock import MagicMock, patch +import pytest + +from mssql_python.exceptions import OperationalError + SAMPLE_TOKEN = secrets.token_hex(44) @@ -22,6 +26,7 @@ def _make_cursor(connection_str, auth_type): mock_conn = MagicMock() mock_conn.connection_str = connection_str mock_conn._auth_type = auth_type + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -108,3 +113,168 @@ def capture_context(ctx, **kwargs): assert "access_token" not in captured_context assert captured_context.get("user_name") == "sa" assert captured_context.get("password") == "mypwd" + + +class TestBulkcopyTokenProvider: + """Verify cursor.bulkcopy acquires a token from a custom token_provider.""" + + @patch("mssql_python.cursor.logger") + def test_token_provider_replaces_auth_fields(self, mock_logger): + """token_provider present ⇒ fresh token injected, stale auth keys removed.""" + mock_logger.is_debug_enabled = False + + # Custom credential whose get_token returns an AccessToken-like object. 
class TestBulkcopyTokenProvider:
    """Verify cursor.bulkcopy acquires a token from a custom token_provider.

    Every test builds a cursor over a mocked connection via ``_make_cursor``,
    attaches a ``MagicMock`` credential as ``_token_provider``, and swaps a
    mock ``mssql_py_core`` module into ``sys.modules`` so that ``bulkcopy``
    never reaches a real backend.
    """

    @patch("mssql_python.cursor.logger")
    def test_token_provider_replaces_auth_fields(self, mock_logger):
        """token_provider present ⇒ fresh token injected, stale auth keys removed."""
        mock_logger.is_debug_enabled = False

        # Custom credential whose get_token returns an AccessToken-like object.
        credential = MagicMock()
        credential.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000)

        cursor = _make_cursor(
            "Server=tcp:test.database.windows.net;Database=testdb;"
            "Authentication=ActiveDirectoryDefault;UID=user@test.com;PWD=secret",
            "activedirectorydefault",
        )
        # token_provider takes precedence over _auth_type.
        cursor._connection._token_provider = credential

        captured_context = {}

        mock_pycore_cursor = MagicMock()
        mock_pycore_cursor.bulkcopy.return_value = {
            "rows_copied": 1,
            "batch_count": 1,
            "elapsed_time": 0.1,
        }
        mock_pycore_conn = MagicMock()
        mock_pycore_conn.cursor.return_value = mock_pycore_cursor

        def capture_context(ctx, **kwargs):
            # Record the context dict that bulkcopy hands to PyCoreConnection.
            captured_context.update(ctx)
            return mock_pycore_conn

        mock_pycore_module = MagicMock()
        mock_pycore_module.PyCoreConnection = capture_context

        with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
            cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10)

        # The credential was consulted for a fresh token.
        credential.get_token.assert_called_once()
        assert captured_context.get("access_token") == SAMPLE_TOKEN
        assert "authentication" not in captured_context
        assert "user_name" not in captured_context
        assert "password" not in captured_context

    @patch("mssql_python.cursor.logger")
    def test_token_provider_get_token_failure_rewrapped(self, mock_logger):
        """credential.get_token raising ⇒ bulkcopy raises OperationalError."""
        mock_logger.is_debug_enabled = False

        credential = MagicMock()
        credential.get_token.side_effect = RuntimeError("network down")

        cursor = _make_cursor(
            "Server=tcp:test.database.windows.net;Database=testdb;"
            "Authentication=ActiveDirectoryDefault",
            "activedirectorydefault",
        )
        cursor._connection._token_provider = credential

        mock_pycore_module = MagicMock()

        with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
            with pytest.raises(OperationalError) as exc_info:
                cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10)

        assert "unable to acquire token from custom credential" in str(exc_info.value)

    @patch("mssql_python.cursor.logger")
    def test_each_bulkcopy_reacquires_fresh_token(self, mock_logger):
        """Every bulkcopy() call asks the provider for a fresh token (no reuse
        of a possibly-stale token across operations)."""
        mock_logger.is_debug_enabled = False

        credential = MagicMock()
        credential.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000)

        cursor = _make_cursor(
            "Server=tcp:test.database.windows.net;Database=testdb;"
            "Authentication=ActiveDirectoryDefault",
            "activedirectorydefault",
        )
        cursor._connection._token_provider = credential

        mock_pycore_cursor = MagicMock()
        mock_pycore_cursor.bulkcopy.return_value = {
            "rows_copied": 1,
            "batch_count": 1,
            "elapsed_time": 0.1,
        }
        mock_pycore_conn = MagicMock()
        mock_pycore_conn.cursor.return_value = mock_pycore_cursor
        mock_pycore_module = MagicMock()
        mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn

        with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
            cursor.bulkcopy("dbo.t", [(1, "a")], timeout=10)
            cursor.bulkcopy("dbo.t", [(2, "b")], timeout=10)
            cursor.bulkcopy("dbo.t", [(3, "c")], timeout=10)

        # One provider call per bulkcopy() invocation.
        assert credential.get_token.call_count == 3

    @patch("mssql_python.cursor.logger")
    def test_transient_failure_then_recovery(self, mock_logger):
        """A transient provider failure on one bulkcopy raises OperationalError
        but leaves the cursor usable for a subsequent successful call."""
        mock_logger.is_debug_enabled = False

        credential = MagicMock()
        good = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000)
        # First call fails, second call succeeds.
        credential.get_token.side_effect = [RuntimeError("network blip"), good]

        cursor = _make_cursor(
            "Server=tcp:test.database.windows.net;Database=testdb;"
            "Authentication=ActiveDirectoryDefault",
            "activedirectorydefault",
        )
        cursor._connection._token_provider = credential

        mock_pycore_cursor = MagicMock()
        mock_pycore_cursor.bulkcopy.return_value = {
            "rows_copied": 1,
            "batch_count": 1,
            "elapsed_time": 0.1,
        }
        mock_pycore_conn = MagicMock()
        mock_pycore_conn.cursor.return_value = mock_pycore_cursor
        mock_pycore_module = MagicMock()
        mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn

        with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
            with pytest.raises(OperationalError):
                cursor.bulkcopy("dbo.t", [(1, "a")], timeout=10)
            # Cursor still works after the transient failure.
            cursor.bulkcopy("dbo.t", [(2, "b")], timeout=10)

        assert credential.get_token.call_count == 2

    @patch("mssql_python.cursor.logger")
    def test_token_provider_invalid_token_rewrapped(self, mock_logger):
        """credential returns an empty token ⇒ bulkcopy raises OperationalError."""
        mock_logger.is_debug_enabled = False

        # .token is not a non-empty string ⇒ _get_token_from_credential raises InterfaceError,
        # which cursor.bulkcopy catches and re-wraps as OperationalError.
        credential = MagicMock()
        credential.get_token.return_value = MagicMock(token="", expires_on=None)

        cursor = _make_cursor(
            "Server=tcp:test.database.windows.net;Database=testdb;"
            "Authentication=ActiveDirectoryDefault",
            "activedirectorydefault",
        )
        cursor._connection._token_provider = credential

        mock_pycore_module = MagicMock()

        with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}):
            with pytest.raises(OperationalError) as exc_info:
                cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10)

        assert "unable to acquire token from custom credential" in str(exc_info.value)