diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index 2c770d1..dda18a6 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -647,12 +647,15 @@ def __init__( self.retrieve_settings = retrieve_settings def _build_agent_storage_client( - self, config: Optional[Config] = None + self, + config: Optional[Config] = None, + ots_endpoint: Optional[str] = None, ) -> AgentStorageClient: """构建 AgentStorageClient / Build AgentStorageClient Args: config: 配置 / Configuration + ots_endpoint: OTS 访问域名 / OTS endpoint Returns: AgentStorageClient: OTS 存储客户端 @@ -661,18 +664,30 @@ def _build_agent_storage_client( raise ValueError("provider_settings is required for OTS retrieval") cfg = Config.with_configs(self.config, config) - ots_endpoint = cfg.get_ots_endpoint( - self.provider_settings.ots_instance_name - ) + ots_instance_name = self.provider_settings.ots_instance_name + if ots_endpoint is None: + ots_endpoint = cfg.get_ots_endpoint(ots_instance_name) return AgentStorageClient( access_key_id=cfg.get_access_key_id(), access_key_secret=cfg.get_access_key_secret(), sts_token=cfg.get_security_token(), ots_endpoint=ots_endpoint, - ots_instance_name=self.provider_settings.ots_instance_name, + ots_instance_name=ots_instance_name, ) + def _build_frontend_ots_endpoint( + self, config: Optional[Config] = None + ) -> str: + cfg = Config.with_configs(self.config, config) + return f"http://ots-{cfg.get_region_id()}.aliyuncs.com" + + def _can_fallback_to_frontend_ots_endpoint( + self, config: Optional[Config] = None + ) -> bool: + cfg = Config.with_configs(self.config, config) + return not cfg.get_use_vpc_endpoint() + def _build_retrieval_configuration( self, filter: Optional[Dict[str, Any]] = None ) -> Optional[Dict[str, Any]]: @@ -838,7 +853,22 @@ async def retrieve_async( request["retrievalConfiguration"] = retrieval_config logger.debug(f"OTS retrieve request: {request}") - response = client.retrieve(request) + try: + response = client.retrieve(request) + except Exception as instance_error: + if not self._can_fallback_to_frontend_ots_endpoint(config): + raise + frontend_endpoint = self._build_frontend_ots_endpoint(config) + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}' with instance endpoint, " + f"fallback to frontend endpoint {frontend_endpoint}: " + f"{instance_error}" + ) + client = self._build_agent_storage_client( + config, ots_endpoint=frontend_endpoint + ) + response = client.retrieve(request) logger.debug(f"OTS retrieve response: {response}") return self._parse_retrieve_response(response, query) diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index ca0ac35..aa3c05a 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -927,12 +927,15 @@ def __init__( self.retrieve_settings = retrieve_settings def _build_agent_storage_client( - self, config: Optional[Config] = None + self, + config: Optional[Config] = None, + ots_endpoint: Optional[str] = None, ) -> AgentStorageClient: """构建 AgentStorageClient / Build AgentStorageClient Args: config: 配置 / Configuration + ots_endpoint: OTS 访问域名 / OTS endpoint Returns: AgentStorageClient: OTS 存储客户端 @@ -941,18 +944,30 @@ def _build_agent_storage_client( raise ValueError("provider_settings is required for OTS retrieval") cfg = Config.with_configs(self.config, config) - ots_endpoint = cfg.get_ots_endpoint( - self.provider_settings.ots_instance_name - ) + ots_instance_name = self.provider_settings.ots_instance_name + if ots_endpoint is None: + ots_endpoint = cfg.get_ots_endpoint(ots_instance_name) return AgentStorageClient( access_key_id=cfg.get_access_key_id(), access_key_secret=cfg.get_access_key_secret(), sts_token=cfg.get_security_token(), ots_endpoint=ots_endpoint, - ots_instance_name=self.provider_settings.ots_instance_name, + ots_instance_name=ots_instance_name, ) + def _build_frontend_ots_endpoint( + self, config: Optional[Config] = None + ) -> str: + cfg = Config.with_configs(self.config, config) + return f"http://ots-{cfg.get_region_id()}.aliyuncs.com" + + def _can_fallback_to_frontend_ots_endpoint( + self, config: Optional[Config] = None + ) -> bool: + cfg = Config.with_configs(self.config, config) + return not cfg.get_use_vpc_endpoint() + def _build_retrieval_configuration( self, filter: Optional[Dict[str, Any]] = None ) -> Optional[Dict[str, Any]]: @@ -1118,7 +1133,22 @@ async def retrieve_async( request["retrievalConfiguration"] = retrieval_config logger.debug(f"OTS retrieve request: {request}") - response = client.retrieve(request) + try: + response = client.retrieve(request) + except Exception as instance_error: + if not self._can_fallback_to_frontend_ots_endpoint(config): + raise + frontend_endpoint = self._build_frontend_ots_endpoint(config) + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}' with instance endpoint, " + f"fallback to frontend endpoint {frontend_endpoint}: " + f"{instance_error}" + ) + client = self._build_agent_storage_client( + config, ots_endpoint=frontend_endpoint + ) + response = client.retrieve(request) logger.debug(f"OTS retrieve response: {response}") return self._parse_retrieve_response(response, query) @@ -1174,7 +1204,22 @@ def retrieve( request["retrievalConfiguration"] = retrieval_config logger.debug(f"OTS retrieve request: {request}") - response = client.retrieve(request) + try: + response = client.retrieve(request) + except Exception as instance_error: + if not self._can_fallback_to_frontend_ots_endpoint(config): + raise + frontend_endpoint = self._build_frontend_ots_endpoint(config) + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}' with instance endpoint, " + f"fallback to frontend endpoint {frontend_endpoint}: " + f"{instance_error}" + ) + client = self._build_agent_storage_client( + config, ots_endpoint=frontend_endpoint + ) + response = client.retrieve(request) logger.debug(f"OTS retrieve response: {response}") return self._parse_retrieve_response(response, query) diff --git a/agentrun/utils/config.py b/agentrun/utils/config.py index 3a9fe4c..ba31276 100644 --- a/agentrun/utils/config.py +++ b/agentrun/utils/config.py @@ -343,7 +343,7 @@ def get_ots_endpoint(self, instance_name: str) -> str: return ( f"https://{instance_name}.{region_id}.vpc.tablestore.aliyuncs.com" ) - return f"http://ots-{region_id}.aliyuncs.com" + return f"https://{instance_name}.{region_id}.ots.aliyuncs.com" def get_use_vpc_endpoint(self) -> bool: """知识库检索是否使用 VPC 内网 endpoint""" diff --git a/tests/unittests/knowledgebase/test_ots_knowledgebase.py b/tests/unittests/knowledgebase/test_ots_knowledgebase.py index 17c18bb..bbd732c 100644 --- a/tests/unittests/knowledgebase/test_ots_knowledgebase.py +++ b/tests/unittests/knowledgebase/test_ots_knowledgebase.py @@ -493,6 +493,104 @@ def test_retrieve_error_handling(self, mock_build_client): assert result["query"] == "test query" assert result["knowledge_base_name"] == "test-kb" + @patch("agentrun.knowledgebase.api.data.AgentStorageClient") + def test_retrieve_falls_back_to_frontend_endpoint( + self, mock_client_class + ): + """测试实例域名检索失败后回退到大前端域名""" + mock_config = MagicMock(spec=Config) + mock_config.get_region_id.return_value = "cn-hangzhou" + mock_config.get_use_vpc_endpoint.return_value = False + mock_config.get_ots_endpoint.return_value = ( + "https://test-instance.cn-hangzhou.ots.aliyuncs.com" + ) + mock_config.get_access_key_id.return_value = "test-ak" + mock_config.get_access_key_secret.return_value = "test-sk" + mock_config.get_security_token.return_value = "test-sts" + + instance_client = MagicMock() + instance_client.retrieve.side_effect = RuntimeError( + "instance endpoint unavailable" + ) + fallback_client = MagicMock() + fallback_client.retrieve.return_value = { + "code": "SUCCESS", + "data": { + "retrievalResults": [ + {"content": "fallback result", "score": 0.8} + ] + }, + } + mock_client_class.side_effect = [instance_client, fallback_client] + + with patch.object(Config, "with_configs", return_value=mock_config): + api = OTSDataAPI( + "test-kb", + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + result = api.retrieve("test query") + + assert "error" not in result + assert result["data"][0]["content"] == "fallback result" + assert mock_client_class.call_args_list[0].kwargs["ots_endpoint"] == ( + "https://test-instance.cn-hangzhou.ots.aliyuncs.com" + ) + assert mock_client_class.call_args_list[1].kwargs["ots_endpoint"] == ( + "http://ots-cn-hangzhou.aliyuncs.com" + ) + instance_client.retrieve.assert_called_once() + fallback_client.retrieve.assert_called_once() + + @patch("agentrun.knowledgebase.api.data.AgentStorageClient") + def test_retrieve_vpc_failure_does_not_fallback_to_public_endpoint( + self, mock_client_class + ): + """测试 VPC endpoint 失败时不回退到公网大前端域名""" + mock_config = MagicMock(spec=Config) + mock_config.get_region_id.return_value = "cn-hangzhou" + mock_config.get_use_vpc_endpoint.return_value = True + mock_config.get_ots_endpoint.return_value = ( + "https://test-instance.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ) + mock_config.get_access_key_id.return_value = "test-ak" + mock_config.get_access_key_secret.return_value = "test-sk" + mock_config.get_security_token.return_value = "test-sts" + + instance_client = MagicMock() + instance_client.retrieve.side_effect = RuntimeError( + "vpc endpoint unavailable" + ) + fallback_client = MagicMock() + fallback_client.retrieve.return_value = { + "code": "SUCCESS", + "data": { + "retrievalResults": [ + {"content": "public fallback result", "score": 0.8} + ] + }, + } + mock_client_class.side_effect = [instance_client, fallback_client] + + with patch.object(Config, "with_configs", return_value=mock_config): + api = OTSDataAPI( + "test-kb", + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + result = api.retrieve("test query") + + assert result["error"] is True + assert "vpc endpoint unavailable" in result["data"] + assert mock_client_class.call_count == 1 + assert mock_client_class.call_args.kwargs["ots_endpoint"] == ( + "https://test-instance.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ) + instance_client.retrieve.assert_called_once() + fallback_client.retrieve.assert_not_called() + @patch( "agentrun.knowledgebase.api.data.OTSDataAPI._build_agent_storage_client" ) @@ -512,6 +610,106 @@ async def test_retrieve_async_error_handling(self, mock_build_client): assert result["error"] is True assert "Failed to retrieve" in result["data"] + @patch("agentrun.knowledgebase.api.data.AgentStorageClient") + @pytest.mark.asyncio + async def test_retrieve_async_falls_back_to_frontend_endpoint( + self, mock_client_class + ): + """测试异步实例域名检索失败后回退到大前端域名""" + mock_config = MagicMock(spec=Config) + mock_config.get_region_id.return_value = "cn-hangzhou" + mock_config.get_use_vpc_endpoint.return_value = False + mock_config.get_ots_endpoint.return_value = ( + "https://test-instance.cn-hangzhou.ots.aliyuncs.com" + ) + mock_config.get_access_key_id.return_value = "test-ak" + mock_config.get_access_key_secret.return_value = "test-sk" + mock_config.get_security_token.return_value = "test-sts" + + instance_client = MagicMock() + instance_client.retrieve.side_effect = RuntimeError( + "instance endpoint unavailable" + ) + fallback_client = MagicMock() + fallback_client.retrieve.return_value = { + "code": "SUCCESS", + "data": { + "retrievalResults": [ + {"content": "async fallback result", "score": 0.8} + ] + }, + } + mock_client_class.side_effect = [instance_client, fallback_client] + + with patch.object(Config, "with_configs", return_value=mock_config): + api = OTSDataAPI( + "test-kb", + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + result = await api.retrieve_async("test query") + + assert "error" not in result + assert result["data"][0]["content"] == "async fallback result" + assert mock_client_class.call_args_list[0].kwargs["ots_endpoint"] == ( + "https://test-instance.cn-hangzhou.ots.aliyuncs.com" + ) + assert mock_client_class.call_args_list[1].kwargs["ots_endpoint"] == ( + "http://ots-cn-hangzhou.aliyuncs.com" + ) + instance_client.retrieve.assert_called_once() + fallback_client.retrieve.assert_called_once() + + @patch("agentrun.knowledgebase.api.data.AgentStorageClient") + @pytest.mark.asyncio + async def test_retrieve_async_vpc_failure_does_not_fallback_to_public_endpoint( + self, mock_client_class + ): + """测试异步 VPC endpoint 失败时不回退到公网大前端域名""" + mock_config = MagicMock(spec=Config) + mock_config.get_region_id.return_value = "cn-hangzhou" + mock_config.get_use_vpc_endpoint.return_value = True + mock_config.get_ots_endpoint.return_value = ( + "https://test-instance.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ) + mock_config.get_access_key_id.return_value = "test-ak" + mock_config.get_access_key_secret.return_value = "test-sk" + mock_config.get_security_token.return_value = "test-sts" + + instance_client = MagicMock() + instance_client.retrieve.side_effect = RuntimeError( + "vpc endpoint unavailable" + ) + fallback_client = MagicMock() + fallback_client.retrieve.return_value = { + "code": "SUCCESS", + "data": { + "retrievalResults": [ + {"content": "async public fallback result", "score": 0.8} + ] + }, + } + mock_client_class.side_effect = [instance_client, fallback_client] + + with patch.object(Config, "with_configs", return_value=mock_config): + api = OTSDataAPI( + "test-kb", + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + result = await api.retrieve_async("test query") + + assert result["error"] is True + assert "vpc endpoint unavailable" in result["data"] + assert mock_client_class.call_count == 1 + assert mock_client_class.call_args.kwargs["ots_endpoint"] == ( + "https://test-instance.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ) + instance_client.retrieve.assert_called_once() + fallback_client.retrieve.assert_not_called() + def test_retrieve_without_provider_settings(self): """测试无 provider_settings 时检索""" api = OTSDataAPI("test-kb") @@ -598,7 +796,7 @@ def test_build_client(self, mock_client_class): """测试构建客户端""" mock_config = MagicMock(spec=Config) mock_config.get_ots_endpoint.return_value = ( - "http://ots-cn-hangzhou.aliyuncs.com" + "https://test-instance.cn-hangzhou.ots.aliyuncs.com" ) mock_config.get_access_key_id.return_value = "test-ak" mock_config.get_access_key_secret.return_value = "test-sk" @@ -617,7 +815,7 @@ def test_build_client(self, mock_client_class): access_key_id="test-ak", access_key_secret="test-sk", sts_token="test-sts", - ots_endpoint="http://ots-cn-hangzhou.aliyuncs.com", + ots_endpoint="https://test-instance.cn-hangzhou.ots.aliyuncs.com", ots_instance_name="test-instance", ) diff --git a/tests/unittests/utils/test_config.py b/tests/unittests/utils/test_config.py index 029293c..9003b78 100644 --- a/tests/unittests/utils/test_config.py +++ b/tests/unittests/utils/test_config.py @@ -35,7 +35,7 @@ def test_kb_endpoints_default_public(self): assert config.get_bailian_endpoint() == "https://bailian.cn-beijing.aliyuncs.com" assert config.get_gpdb_endpoint() == "gpdb.aliyuncs.com" assert config.get_ots_endpoint("my-instance") == ( - "http://ots-cn-hangzhou.aliyuncs.com" + "https://my-instance.cn-hangzhou.ots.aliyuncs.com" ) def test_kb_endpoints_vpc_mode(self):