diff --git a/.pylintrc b/.pylintrc index ece807a2..7dc311f7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -336,7 +336,7 @@ exclude-protected=_asdict,_fields,_replace,_source,_make [DESIGN] # Maximum number of arguments for function / method -max-args=7 +max-args=8 # Argument names that match this expression will be ignored. Default to name # with leading underscore @@ -358,7 +358,7 @@ max-statements=50 max-parents=7 # Maximum number of attributes for a class (see R0902). -max-attributes=7 +max-attributes=8 # Minimum number of public methods for a class (see R0903). min-public-methods=1 diff --git a/CHANGELOG.md b/CHANGELOG.md index f952c045..62af5662 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +- service: `response_hook` parameter, enables inspection or rejection of raw responses (e.g. header-encoded SAP domain errors) without leaking HTTP transport objects through the OData API boundary. +- vendor/SAP: `sap_header_error_hook(response)` — a stateless hook that detects SAP domain errors encoded in the `sap-message` response header and raises `BusinessGatewayError` before pyodata's domain handler runs. - service: let FunctionRequests return a list of EntityProxies instead of the raw json, when the `ReturnType` is a Collection. - Emil B. ## [1.11.2] diff --git a/docs/usage/advanced.rst b/docs/usage/advanced.rst index cb361ada..8e65c373 100644 --- a/docs/usage/advanced.rst +++ b/docs/usage/advanced.rst @@ -105,3 +105,32 @@ If you need to work with many Entity Sets the same way or if you just need to pi count = getattr(northwind.entity_sets, 'Employees').get_entities().count().execute() print(count) + +Inspecting responses with a hook +--------------------------------- + +Some OData services communicate domain errors via HTTP response headers on otherwise-200 +responses. Because pyodata discards headers before returning the domain result, callers +cannot detect these errors through the normal return value. + +``Client`` (and ``Service`` directly) accept an optional ``response_hook`` parameter — a +``Callable[[response], None]`` that fires before the domain handler runs, for every request +type (including ``async_execute()``). The hook receives the raw response object. Raising an +exception from the hook propagates to the caller and prevents the domain handler from running. + +.. code-block:: python + + import pyodata + import requests + + SERVICE_URL = 'https://odata.example.com/MyService.svc' + + def my_hook(response): + if response.headers.get('x-custom-error'): + raise RuntimeError(f"Service signalled error: {response.headers['x-custom-error']}") + + service = pyodata.Client(SERVICE_URL, requests.Session(), response_hook=my_hook) + +The hook must be stateless to be safe under concurrent and async use. If you need to +handle SAP-specific header errors, use the ready-made hook in ``pyodata.vendor.SAP`` +— see :doc:`vendors`. diff --git a/docs/usage/vendors.rst b/docs/usage/vendors.rst index 7079b5c3..4d58545d 100644 --- a/docs/usage/vendors.rst +++ b/docs/usage/vendors.rst @@ -34,3 +34,35 @@ The following code demonstrates using the helper. session = SAP.add_btp_token_to_session(requests.Session(), KEY, USER, PASSWORD) # do something more with session object if necessary (e.g. adding sap-client parameter, or CSRF token) client = pyodata.Client(SERVICE_URL, session) + +Detecting SAP domain errors in response headers +------------------------------------------------ + +Some SAP OData services signal domain errors via the ``sap-message`` response header on +otherwise-200 responses. pyodata's domain handler never sees these headers, so callers +cannot detect the error through the normal return value. + +``pyodata.vendor.SAP`` provides a ready-made stateless hook, ``sap_header_error_hook``, +that reads the ``sap-message`` header, parses it as JSON, and raises ``BusinessGatewayError`` +when the ``severity`` field equals ``"error"``. Pass it as the ``response_hook`` argument to +``Client`` (or directly to ``Service``): + +.. code-block:: python + + import pyodata + from pyodata.vendor.SAP import sap_header_error_hook + import requests + + SERVICE_URL = 'https://example.com/sap/opu/odata/sap/ZMyService' + + session = requests.Session() + client = pyodata.Client(SERVICE_URL, session, response_hook=sap_header_error_hook) + + try: + result = client.entity_sets.Employees.get_entity(1).execute() + except pyodata.vendor.SAP.BusinessGatewayError as ex: + print(f"SAP domain error: {ex}") + +The hook fires before pyodata's domain handler, so the exception propagates before any +result object is constructed. It is safe under concurrent and async use because it holds +no instance state. diff --git a/pyodata/client.py b/pyodata/client.py index f2383fc8..056ad24b 100644 --- a/pyodata/client.py +++ b/pyodata/client.py @@ -55,7 +55,8 @@ class Client: @staticmethod async def build_async_client(url, connection, odata_version=ODATA_VERSION_2, namespaces=None, - config: pyodata.v2.model.Config = None, metadata: str = None): + config: pyodata.v2.model.Config = None, metadata: str = None, + response_hook=None): """Create instance of the OData Client for given URL""" logger = logging.getLogger('pyodata.client') @@ -69,11 +70,12 @@ async def build_async_client(url, connection, odata_version=ODATA_VERSION_2, nam metadata = await _async_fetch_metadata(connection, url, logger) else: logger.info('Using static metadata') - return Client._build_service(logger, url, connection, odata_version, namespaces, config, metadata) + return Client._build_service(logger, url, connection, odata_version, namespaces, config, metadata, + response_hook=response_hook) raise PyODataException(f'No implementation for selected odata version {odata_version}') def __new__(cls, url, connection, odata_version=ODATA_VERSION_2, namespaces=None, - config: pyodata.v2.model.Config = None, metadata: str = None): + config: pyodata.v2.model.Config = None, metadata: str = None, response_hook=None): """Create instance of the OData Client for given URL""" logger = logging.getLogger('pyodata.client') @@ -88,12 +90,13 @@ def __new__(cls, url, connection, odata_version=ODATA_VERSION_2, namespaces=None else: logger.info('Using static metadata') - return Client._build_service(logger, url, connection, odata_version, namespaces, config, metadata) + return Client._build_service(logger, url, connection, odata_version, namespaces, config, metadata, + response_hook=response_hook) raise PyODataException(f'No implementation for selected odata version {odata_version}') @staticmethod def _build_service(logger, url, connection, odata_version=ODATA_VERSION_2, namespaces=None, - config: pyodata.v2.model.Config = None, metadata: str = None): + config: pyodata.v2.model.Config = None, metadata: str = None, response_hook=None): if config is not None and namespaces is not None: raise PyODataException('You cannot pass namespaces and config at the same time') @@ -111,6 +114,6 @@ def _build_service(logger, url, connection, odata_version=ODATA_VERSION_2, names # create service instance based on model we have logger.info('Creating OData Service (version: %d)', odata_version) - service = pyodata.v2.service.Service(url, schema, connection, config=config) + service = pyodata.v2.service.Service(url, schema, connection, config=config, response_hook=response_hook) return service diff --git a/pyodata/v2/service.py b/pyodata/v2/service.py index b460e0b1..d170c1e8 100644 --- a/pyodata/v2/service.py +++ b/pyodata/v2/service.py @@ -232,7 +232,7 @@ def __repr__(self): class ODataHttpRequest: """Deferred HTTP Request""" - def __init__(self, url, connection, handler, headers=None): + def __init__(self, url, connection, handler, headers=None, response_hook=None): self._connection = connection self._url = url self._handler = handler @@ -240,6 +240,7 @@ def __init__(self, url, connection, handler, headers=None): self._logger = logging.getLogger(LOGGER_NAME) self._customs = {} # string -> string hash self._next_url = None + self._response_hook = response_hook @property def handler(self): @@ -359,6 +360,9 @@ def _call_handler(self, response): except UnicodeDecodeError: self._logger.debug(' body: ') + if self._response_hook is not None: + self._response_hook(response) + return self._handler(response) def custom(self, name, value): @@ -373,7 +377,7 @@ class EntityGetRequest(ODataHttpRequest): def __init__(self, handler, entity_key, entity_set_proxy, encode_path=True): super(EntityGetRequest, self).__init__(entity_set_proxy.service.url, entity_set_proxy.service.connection, - handler) + handler, response_hook=entity_set_proxy.service.response_hook) self._logger = logging.getLogger(LOGGER_NAME) self._entity_key = entity_key self._entity_set_proxy = entity_set_proxy @@ -465,8 +469,8 @@ class EntityCreateRequest(ODataHttpRequest): Call execute() to send the create-request to the OData service and get the newly created entity.""" - def __init__(self, url, connection, handler, entity_set, last_segment=None): - super(EntityCreateRequest, self).__init__(url, connection, handler) + def __init__(self, url, connection, handler, entity_set, last_segment=None, response_hook=None): + super(EntityCreateRequest, self).__init__(url, connection, handler, response_hook=response_hook) self._logger = logging.getLogger(LOGGER_NAME) self._entity_set = entity_set self._entity_type = entity_set.entity_type @@ -552,8 +556,8 @@ def set(self, **kwargs): class EntityDeleteRequest(ODataHttpRequest): """Used for deleting entity (DELETE operations on a single entity)""" - def __init__(self, url, connection, handler, entity_set, entity_key, encode_path=True): - super(EntityDeleteRequest, self).__init__(url, connection, handler) + def __init__(self, url, connection, handler, entity_set, entity_key, encode_path=True, response_hook=None): + super(EntityDeleteRequest, self).__init__(url, connection, handler, response_hook=response_hook) self._logger = logging.getLogger(LOGGER_NAME) self._entity_set = entity_set self._entity_key = entity_key @@ -585,8 +589,9 @@ class EntityModifyRequest(ODataHttpRequest): ALLOWED_HTTP_METHODS = ['PATCH', 'PUT', 'MERGE'] # pylint: disable=too-many-arguments - def __init__(self, url, connection, handler, entity_set, entity_key, method="PATCH", encode_path=True): - super(EntityModifyRequest, self).__init__(url, connection, handler) + def __init__(self, url, connection, handler, entity_set, entity_key, method="PATCH", encode_path=True, + response_hook=None): + super(EntityModifyRequest, self).__init__(url, connection, handler, response_hook=response_hook) self._logger = logging.getLogger(LOGGER_NAME) self._entity_set = entity_set self._entity_type = entity_set.entity_type @@ -650,8 +655,8 @@ class QueryRequest(ODataHttpRequest): # pylint: disable=too-many-instance-attributes - def __init__(self, url, connection, handler, last_segment): - super(QueryRequest, self).__init__(url, connection, handler) + def __init__(self, url, connection, handler, last_segment, response_hook=None): + super(QueryRequest, self).__init__(url, connection, handler, response_hook=response_hook) self._logger = logging.getLogger(LOGGER_NAME) self._count = None @@ -767,8 +772,10 @@ def get_query_params(self): class FunctionRequest(QueryRequest): """Function import request (Service call)""" - def __init__(self, url, connection, handler, function_import): - super(FunctionRequest, self).__init__(url, connection, handler, function_import.name) + def __init__(self, url, connection, handler, function_import, response_hook=None): + super(FunctionRequest, self).__init__( + url, connection, handler, function_import.name, + response_hook=response_hook) self._function_import = function_import @@ -1332,8 +1339,8 @@ def __str__(self): class GetEntitySetRequest(QueryRequest): """GET on EntitySet""" - def __init__(self, url, connection, handler, last_segment, entity_type, encode_path=True): - super(GetEntitySetRequest, self).__init__(url, connection, handler, last_segment) + def __init__(self, url, connection, handler, last_segment, entity_type, encode_path=True, response_hook=None): + super(GetEntitySetRequest, self).__init__(url, connection, handler, last_segment, response_hook=response_hook) self._entity_type = entity_type self._encode_path = encode_path @@ -1554,7 +1561,7 @@ def get_entities_handler(response): entity_set_name = self._alias if self._alias is not None else self._entity_set.name return GetEntitySetRequest(self._service.url, self._service.connection, get_entities_handler, self._parent_last_segment + entity_set_name, self._entity_set.entity_type, - encode_path=encode_path) + encode_path=encode_path, response_hook=self._service.response_hook) def create_entity(self, return_code=HTTP_CODE_CREATED): """Creates a new entity in the given entity-set.""" @@ -1572,7 +1579,7 @@ def create_entity_handler(response): return EntityProxy(self._service, self._entity_set, self._entity_set.entity_type, entity_props, etag=etag) return EntityCreateRequest(self._service.url, self._service.connection, create_entity_handler, self._entity_set, - self.last_segment) + self.last_segment, response_hook=self._service.response_hook) def update_entity(self, key=None, method=None, encode_path=True, **kwargs): """Updates an existing entity in the given entity-set.""" @@ -1595,7 +1602,8 @@ def update_entity_handler(response): method = self._service.config['http']['update_method'] return EntityModifyRequest(self._service.url, self._service.connection, update_entity_handler, self._entity_set, - entity_key, method=method, encode_path=encode_path) + entity_key, method=method, encode_path=encode_path, + response_hook=self._service.response_hook) def delete_entity(self, key: EntityKey = None, encode_path=True, **kwargs): """Delete the entity""" @@ -1614,7 +1622,7 @@ def delete_entity_handler(response): entity_key = EntityKey(self._entity_set.entity_type, key, **kwargs) return EntityDeleteRequest(self._service.url, self._service.connection, delete_entity_handler, self._entity_set, - entity_key, encode_path=encode_path) + entity_key, encode_path=encode_path, response_hook=self._service.response_hook) # pylint: disable=too-few-public-methods @@ -1735,17 +1743,19 @@ def function_import_handler(fimport, response): return response_data return FunctionRequest(self._service.url, self._service.connection, - partial(function_import_handler, fimport), fimport) + partial(function_import_handler, fimport), fimport, + response_hook=self._service.response_hook) class Service: """OData service""" - def __init__(self, url, schema, connection, config=None): + def __init__(self, url, schema, connection, config=None, response_hook=None): self._url = url self._schema = schema self._connection = connection self._retain_null = config.retain_null if config else False + self._response_hook = response_hook self._entity_container = EntityContainer(self) self._function_container = FunctionContainer(self) @@ -1769,6 +1779,12 @@ def connection(self): return self._connection + @property + def response_hook(self): + """Optional hook called with the raw response before domain handler runs""" + + return self._response_hook + @property def retain_null(self): """Whether to respect null-ed values or to substitute them with type specific default values""" diff --git a/pyodata/vendor/SAP.py b/pyodata/vendor/SAP.py index 070f46cb..ef78b91f 100644 --- a/pyodata/vendor/SAP.py +++ b/pyodata/vendor/SAP.py @@ -47,6 +47,29 @@ def add_btp_token_to_session(session, key, user, password): return session +def sap_header_error_hook(response): + """Response hook that detects SAP domain errors encoded in the sap-message header + on otherwise-200 responses. + + Pass this as response_hook to Service() to raise BusinessGatewayError before + pyodata's domain handler runs: + + service = Service(url, schema, session, response_hook=sap_header_error_hook) + """ + sap_message = response.headers.get('sap-message') + if sap_message is None: + return + + try: + msg = json.loads(sap_message) + except ValueError: + return + + severity = msg.get('severity', '') + if severity == 'error': + raise BusinessGatewayError(msg.get('message', 'SAP header error'), response) + + class BusinessGatewayError(HttpError): """To display the right error message""" diff --git a/tests/test_service_v2.py b/tests/test_service_v2.py index 083c3a42..960b244c 100644 --- a/tests/test_service_v2.py +++ b/tests/test_service_v2.py @@ -3090,4 +3090,54 @@ def test_custom_with_create_entity_url_params(service): assert result.Key == '12345' assert result.Data == 'abcd' - assert result.etag == 'W/\"J0FtZXJpY2FuIEFpcmxpbmVzJw==\"' \ No newline at end of file + assert result.etag == 'W/\"J0FtZXJpY2FuIEFpcmxpbmVzJw==\"' + + +@responses.activate +def test_response_hook_is_called_before_domain_handler(schema): + """response_hook fires before the domain handler and receives the raw response""" + + called_with = [] + + def hook(response): + called_with.append(response.status_code) + + svc = pyodata.v2.service.Service(URL_ROOT, schema, requests, response_hook=hook) + + path = quote("MasterEntities('1')") + responses.add( + responses.GET, + f"{URL_ROOT}/{path}", + headers={'Content-type': 'application/json'}, + json={'d': {'Key': '1', 'Data': 'x'}}, + status=200) + + svc.entity_sets.MasterEntities.get_entity('1').execute() + + assert called_with == [200] + + +@responses.activate +def test_response_hook_raising_prevents_domain_handler(schema): + """An exception raised in response_hook propagates and the domain result is not returned""" + + def hook(response): + raise RuntimeError('hook blocked this') + + svc = pyodata.v2.service.Service(URL_ROOT, schema, requests, response_hook=hook) + + path = quote("MasterEntities('1')") + responses.add( + responses.GET, + f"{URL_ROOT}/{path}", + headers={'Content-type': 'application/json'}, + json={'d': {'Key': '1', 'Data': 'x'}}, + status=200) + + with pytest.raises(RuntimeError, match='hook blocked this'): + svc.entity_sets.MasterEntities.get_entity('1').execute() + + +def test_service_without_response_hook_works(service): + """response_hook defaults to None and does not affect normal operation""" + assert service.response_hook is None \ No newline at end of file diff --git a/tests/test_vendor_sap.py b/tests/test_vendor_sap.py index 916afd54..656e2647 100644 --- a/tests/test_vendor_sap.py +++ b/tests/test_vendor_sap.py @@ -271,3 +271,37 @@ def test_add_btp_token_to_session_invalid_clientid(): assert caught.value.response.status_code == 401 assert json.loads(caught.value.response.text)['error_description'] == 'Bad credentials' + + +class MockHeaderResponse(NamedTuple): + headers: dict + content: ByteString = b'' + status_code: int = 200 + + +def test_sap_header_error_hook_no_header(): + """Hook does nothing when sap-message header is absent""" + response = MockHeaderResponse(headers={}) + SAP.sap_header_error_hook(response) # must not raise + + +def test_sap_header_error_hook_non_error_severity(): + """Hook does nothing for sap-message with severity != error""" + msg = json.dumps({'severity': 'warning', 'message': 'just a warning'}) + response = MockHeaderResponse(headers={'sap-message': msg}) + SAP.sap_header_error_hook(response) # must not raise + + +def test_sap_header_error_hook_invalid_json(): + """Hook does nothing when sap-message is not valid JSON""" + response = MockHeaderResponse(headers={'sap-message': 'not-json'}) + SAP.sap_header_error_hook(response) # must not raise + + +def test_sap_header_error_hook_raises_on_error_severity(): + """Hook raises BusinessGatewayError when sap-message severity is error""" + msg = json.dumps({'severity': 'error', 'message': 'Domain error from header'}) + response = MockHeaderResponse(headers={'sap-message': msg}) + with pytest.raises(SAP.BusinessGatewayError) as caught: + SAP.sap_header_error_hook(response) + assert 'Domain error from header' in str(caught.value)