Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions agentrun/knowledgebase/api/__data_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,12 +647,15 @@ def __init__(
self.retrieve_settings = retrieve_settings

def _build_agent_storage_client(
self, config: Optional[Config] = None
self,
config: Optional[Config] = None,
ots_endpoint: Optional[str] = None,
) -> AgentStorageClient:
"""构建 AgentStorageClient / Build AgentStorageClient

Args:
config: 配置 / Configuration
ots_endpoint: OTS 访问域名 / OTS endpoint

Returns:
AgentStorageClient: OTS 存储客户端
Expand All @@ -661,18 +664,30 @@ def _build_agent_storage_client(
raise ValueError("provider_settings is required for OTS retrieval")

cfg = Config.with_configs(self.config, config)
ots_endpoint = cfg.get_ots_endpoint(
self.provider_settings.ots_instance_name
)
ots_instance_name = self.provider_settings.ots_instance_name
if ots_endpoint is None:
ots_endpoint = cfg.get_ots_endpoint(ots_instance_name)

return AgentStorageClient(
access_key_id=cfg.get_access_key_id(),
access_key_secret=cfg.get_access_key_secret(),
sts_token=cfg.get_security_token(),
ots_endpoint=ots_endpoint,
ots_instance_name=self.provider_settings.ots_instance_name,
ots_instance_name=ots_instance_name,
)

def _build_frontend_ots_endpoint(
self, config: Optional[Config] = None
) -> str:
cfg = Config.with_configs(self.config, config)
return f"http://ots-{cfg.get_region_id()}.aliyuncs.com"
Comment on lines +679 to +683

def _can_fallback_to_frontend_ots_endpoint(
self, config: Optional[Config] = None
) -> bool:
cfg = Config.with_configs(self.config, config)
return not cfg.get_use_vpc_endpoint()

def _build_retrieval_configuration(
self, filter: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -838,7 +853,22 @@ async def retrieve_async(
request["retrievalConfiguration"] = retrieval_config

logger.debug(f"OTS retrieve request: {request}")
response = client.retrieve(request)
try:
response = client.retrieve(request)
except Exception as instance_error:
if not self._can_fallback_to_frontend_ots_endpoint(config):
raise
Comment on lines 855 to +860
frontend_endpoint = self._build_frontend_ots_endpoint(config)
logger.warning(
"Failed to retrieve from OTS knowledge base "
f"'{self.knowledge_base_name}' with instance endpoint, "
f"fallback to frontend endpoint {frontend_endpoint}: "
f"{instance_error}"
)
client = self._build_agent_storage_client(
config, ots_endpoint=frontend_endpoint
)
response = client.retrieve(request)
logger.debug(f"OTS retrieve response: {response}")

return self._parse_retrieve_response(response, query)
Expand Down
59 changes: 52 additions & 7 deletions agentrun/knowledgebase/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,12 +927,15 @@ def __init__(
self.retrieve_settings = retrieve_settings

def _build_agent_storage_client(
self, config: Optional[Config] = None
self,
config: Optional[Config] = None,
ots_endpoint: Optional[str] = None,
) -> AgentStorageClient:
"""构建 AgentStorageClient / Build AgentStorageClient

Args:
config: 配置 / Configuration
ots_endpoint: OTS 访问域名 / OTS endpoint

Returns:
AgentStorageClient: OTS 存储客户端
Expand All @@ -941,18 +944,30 @@ def _build_agent_storage_client(
raise ValueError("provider_settings is required for OTS retrieval")

cfg = Config.with_configs(self.config, config)
ots_endpoint = cfg.get_ots_endpoint(
self.provider_settings.ots_instance_name
)
ots_instance_name = self.provider_settings.ots_instance_name
if ots_endpoint is None:
ots_endpoint = cfg.get_ots_endpoint(ots_instance_name)

return AgentStorageClient(
access_key_id=cfg.get_access_key_id(),
access_key_secret=cfg.get_access_key_secret(),
sts_token=cfg.get_security_token(),
ots_endpoint=ots_endpoint,
ots_instance_name=self.provider_settings.ots_instance_name,
ots_instance_name=ots_instance_name,
)

def _build_frontend_ots_endpoint(
self, config: Optional[Config] = None
) -> str:
cfg = Config.with_configs(self.config, config)
return f"http://ots-{cfg.get_region_id()}.aliyuncs.com"

def _can_fallback_to_frontend_ots_endpoint(
self, config: Optional[Config] = None
) -> bool:
cfg = Config.with_configs(self.config, config)
return not cfg.get_use_vpc_endpoint()

def _build_retrieval_configuration(
self, filter: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -1118,7 +1133,22 @@ async def retrieve_async(
request["retrievalConfiguration"] = retrieval_config

logger.debug(f"OTS retrieve request: {request}")
response = client.retrieve(request)
try:
response = client.retrieve(request)
except Exception as instance_error:
if not self._can_fallback_to_frontend_ots_endpoint(config):
raise
frontend_endpoint = self._build_frontend_ots_endpoint(config)
logger.warning(
"Failed to retrieve from OTS knowledge base "
f"'{self.knowledge_base_name}' with instance endpoint, "
f"fallback to frontend endpoint {frontend_endpoint}: "
f"{instance_error}"
)
client = self._build_agent_storage_client(
config, ots_endpoint=frontend_endpoint
)
response = client.retrieve(request)
logger.debug(f"OTS retrieve response: {response}")

return self._parse_retrieve_response(response, query)
Expand Down Expand Up @@ -1174,7 +1204,22 @@ def retrieve(
request["retrievalConfiguration"] = retrieval_config

logger.debug(f"OTS retrieve request: {request}")
response = client.retrieve(request)
try:
response = client.retrieve(request)
except Exception as instance_error:
if not self._can_fallback_to_frontend_ots_endpoint(config):
raise
frontend_endpoint = self._build_frontend_ots_endpoint(config)
logger.warning(
"Failed to retrieve from OTS knowledge base "
f"'{self.knowledge_base_name}' with instance endpoint, "
f"fallback to frontend endpoint {frontend_endpoint}: "
f"{instance_error}"
)
client = self._build_agent_storage_client(
config, ots_endpoint=frontend_endpoint
)
response = client.retrieve(request)
logger.debug(f"OTS retrieve response: {response}")

return self._parse_retrieve_response(response, query)
Expand Down
2 changes: 1 addition & 1 deletion agentrun/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def get_ots_endpoint(self, instance_name: str) -> str:
return (
f"https://{instance_name}.{region_id}.vpc.tablestore.aliyuncs.com"
)
return f"http://ots-{region_id}.aliyuncs.com"
return f"https://{instance_name}.{region_id}.ots.aliyuncs.com"

def get_use_vpc_endpoint(self) -> bool:
"""知识库检索是否使用 VPC 内网 endpoint"""
Expand Down
Loading
Loading