Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 75 additions & 29 deletions app/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os.path
import app.api.globals as cms_globals

from typing import Dict, Any, Optional, Union, Type
from typing import Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor
from anyio.lowlevel import RunVar
from anyio import CapacityLimiter
Expand All @@ -18,9 +18,14 @@
from app.api.auth.db import make_sure_db_and_tables
from app.api.auth.users import Props
from app.api.dependencies import ModelServiceDep
from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine
from app.api.utils import (
add_exception_handlers,
add_rate_limiter,
init_vllm_engine,
init_sglang_engine,
ForwardedPrefixMiddleware,
)
from app.config import Settings
from app.domain import Tags, TagsStreamable, TagsGenerative
from app.management.tracker_client import TrackerClient
from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name
from app.exception import ConfigurationException
Expand All @@ -29,7 +34,6 @@
logging.getLogger("asyncio").setLevel(logging.ERROR)
logger = logging.getLogger("cms")


def get_model_server(config: Settings, msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
"""
Initialises a FastAPI app instance configured for the CMS model service.
Expand Down Expand Up @@ -111,7 +115,10 @@ def get_stream_server(config: Settings, msd_overwritten: Optional[ModelServiceDe
return app


def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServiceDep] = None) -> FastAPI:
def get_generative_server(
config: Settings,
msd_overwritten: Optional[ModelServiceDep] = None,
) -> FastAPI:
"""
Initialises a FastAPI instance configured for a generative server.

Expand All @@ -134,6 +141,9 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi
if config.ENABLE_TRAINING_APIS == "true":
app = _load_supervised_training_router(app)
logger.debug("Supervised training router loaded")
if config.DISABLE_UNSUPERVISED_TRAINING != "true":
app = _load_unsupervised_training_router(app)
logger.debug("Unsupervised training router loaded")
app = _load_training_operations(app)

if config.AUTH_USER_ENABLED == "true":
Expand All @@ -147,25 +157,71 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi

return app

def get_vllm_server(config: Settings, model_package_path: str, model_name: str, log_level: str = "info") -> FastAPI:
def get_vllm_server(
config: Settings,
model_package_path: str,
model_name: str,
log_level: str = "info",
server_args: Optional[str] = None,
) -> FastAPI:
"""
Initialises a FastAPI instance configured for a vLLM server.

Args:
config (Settings): The CMS configuration.
model_package_path (str): The path to the model package file.
model_name (str): The name of the model.
log_level (str): The log level for the VLLM engine. Default to "info".
log_level (str): The log level for the vLLM engine. Default to "info".
server_args (Optional[str]): The arguments to pass to the vLLM engine.

Returns:
FastAPI: A FastAPI app instance.
"""

app = _get_app(None, streamable=False)
model_dir_path = os.path.join(
os.path.dirname(model_package_path), get_model_data_package_base_name(model_package_path)
)
if unpack_model_data_package(model_package_path, model_dir_path):
async def _startup() -> None:
await init_vllm_engine(app, config, model_dir_path, model_name, log_level, server_args)

app.add_event_handler("startup", _startup)
else:
raise ConfigurationException(f"Model package archive format is not supported: {model_package_path}")

return app


def get_sglang_server(
config: Settings,
model_package_path: str,
model_name: str,
log_level: str = "info",
server_args: Optional[str] = None,
) -> FastAPI:
"""
Initialises a FastAPI instance configured for an SGLang server.

Args:
config (Settings): The CMS configuration.
model_package_path (str): The path to the model package file.
model_name (str): The name of the model.
log_level (str): The log level for the SGLang engine. Default to "info".
server_args (Optional[str]): The arguments to pass to the SGLang engine.
Returns:
FastAPI: A FastAPI app instance.
"""

app = _get_app(None, streamable=False)
model_dir_path = os.path.join(os.path.dirname(model_package_path), get_model_data_package_base_name(model_package_path))
model_dir_path = os.path.join(
os.path.dirname(model_package_path), get_model_data_package_base_name(model_package_path)
)
if unpack_model_data_package(model_package_path, model_dir_path):
loop = asyncio.get_event_loop()
app = loop.run_until_complete(init_vllm_engine(app, model_dir_path, model_name, log_level))
async def _startup() -> None:
await init_sglang_engine(app, config, model_dir_path, model_name, log_level, server_args)

app.add_event_handler("startup", _startup)
else:
raise ConfigurationException(f"Model package archive format is not supported: {model_package_path}")

Expand Down Expand Up @@ -204,32 +260,21 @@ def _get_app(
generative: bool = False,
) -> FastAPI:
config = get_settings()
tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]]
if generative:
tags = TagsGenerative
elif streamable:
tags = TagsStreamable
else:
tags = Tags
tags_metadata = [{
"name": tag.name,
"description": tag.value
} for tag in tags]

app = FastAPI(
title="CogStack ModelServe",
summary="A model serving and governance system for CogStack NLP solutions",
docs_url=None,
redoc_url=None,
debug=(config.DEBUG == "true"),
openapi_tags=tags_metadata,
)

app.add_middleware(ForwardedPrefixMiddleware) # type: ignore
add_exception_handlers(app)

instrumentator = None
if not generative:
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]
).instrument(app)
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/redoc", "/metrics", "/openapi.json", "/favicon.ico", "none"]
).instrument(app)

if msd_overwritten is not None:
cms_globals.model_service_dep = msd_overwritten
Expand Down Expand Up @@ -279,8 +324,9 @@ async def redoc_doc(req: Request) -> HTMLResponse:
)

@app.get("/", include_in_schema=False)
async def root_redirect() -> RedirectResponse:
return RedirectResponse(url="/docs")
async def root_redirect(req: Request) -> RedirectResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
return RedirectResponse(url=f"{req.url.scheme}://{req.url.netloc}{root_path}/docs")

@app.on_event("shutdown")
async def on_shutdown() -> None:
Expand Down
4 changes: 1 addition & 3 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from fastapi import HTTPException, Query
from starlette.status import HTTP_400_BAD_REQUEST

from typing import Optional
from app.config import Settings
from app.domain import ModelType
Expand All @@ -14,11 +13,10 @@
from app.model_services.base import AbstractModelService
from app.management.model_manager import ModelManager

TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")

TRACKING_ID_REGEX = re.compile(r"^[a-zA-Z0-9][\w\-]{0,255}$")
logger = logging.getLogger("cms")


class ModelServiceDep(object):
"""Dependency class for injecting the CMS model service based on the given model type."""

Expand Down
Loading
Loading