transformers: Improving CLI Serving Code Structure with Class-Based FastAPI Patterns [1 comment, 1 participant]

GitHub stats
huggingface/transformers#45696 · Fetched 2026-04-30 06:18:25
Comments: 1 · Participants: 1 · Timeline events: 6 · Reactions: 0
Timeline (top): labeled ×2, closed ×1, commented ×1, mentioned ×1




Feature request

This proposal suggests refactoring the CLI serving code (transformers/cli/serving/server.py) from a function-based to a class-based architecture, using a pattern that better organizes related endpoints, clarifies resource lifecycle management, and improves long-term maintainability.

Motivation

The current build_server function in server.py is well-engineered and works reliably. It correctly uses closures to share heavy resources (model managers, handlers) across all endpoints.
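
To make the closure approach concrete, here is a minimal, framework-free sketch of how inner functions can share one heavy resource. It is not the actual `build_server` code: `FakeModelManager` and `build_handlers` are hypothetical names used only for illustration, and the route table is a plain dictionary rather than a FastAPI app.

```python
from typing import Callable, Dict


class FakeModelManager:
    """Hypothetical stand-in for the heavy, shared model manager."""

    def __init__(self) -> None:
        self.loaded = ["demo-model"]

    def shutdown(self) -> None:
        # Drop every loaded model.
        self.loaded.clear()


def build_handlers(model_manager: FakeModelManager) -> Dict[str, Callable[[], dict]]:
    """Return handlers that all close over the same model_manager instance."""

    def list_models() -> dict:
        return {"object": "list", "data": list(model_manager.loaded)}

    def reset() -> dict:
        model_manager.shutdown()
        return {"status": "ok"}

    # The dependency is implicit: it lives in the enclosing scope, not in a signature.
    return {"/v1/models": list_models, "/reset": reset}


manager = FakeModelManager()
handlers = build_handlers(manager)
before = handlers["/v1/models"]()
handlers["/reset"]()
after = handlers["/v1/models"]()
```

Both handlers observe the same manager, so the reset is visible to the later listing. The dependency never appears in a handler signature, which is the implicitness the proposal aims to make explicit.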

However, as ML serving applications grow in complexity and teams scale, a class-based structure could offer advantages for long-term maintenance:

  1. Explicit Architecture: A class makes dependencies and relationships immediately clear to new contributors
  2. Better Organization: Related endpoints (chat, completions, embeddings, health checks) are naturally grouped
  3. Standardized Patterns: Aligns with how many production teams structure Python services, especially in ML/LLM serving
  4. Future-proofing: Provides a clearer foundation for adding shared middleware, state management, or additional endpoints

This is not about fixing a bug, but about exploring an alternative architectural pattern that could benefit the project as the serving layer evolves.

Your contribution

Yes, I've already implemented a complete refactoring as a proof of concept. Here's the class-based pattern I propose:

import uuid
from contextlib import asynccontextmanager

from ...utils import logging
from ...utils.import_utils import is_serve_available

if is_serve_available():
    from fastapi import FastAPI, Request, APIRouter, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse, StreamingResponse
    from starlette.middleware.base import BaseHTTPMiddleware
    # For this example, I'm using a minimal library (`fastapi-cbx`) for the class-based routing.
    # The core architectural pattern of grouping endpoints in a class is independent of the implementation.
    from cbx import cbr

from .chat_completion import ChatCompletionHandler
from .completion import CompletionHandler
from .model_manager import ModelManager
from .response import ResponseHandler
from .transcription import TranscriptionHandler
from .utils import X_REQUEST_ID


@cbr(router=APIRouter())
class OpenAI:

    logger = logging.get_logger(__qualname__)

    def __init__(
        self,
        model_manager: ModelManager,
        chat_handler: ChatCompletionHandler,
        completion_handler: CompletionHandler,
        response_handler: ResponseHandler,
        transcription_handler: TranscriptionHandler,
    ):
        self.model_manager = model_manager
        self.chat_handler = chat_handler
        self.completion_handler = completion_handler
        self.response_handler = response_handler
        self.transcription_handler = transcription_handler

    @staticmethod
    def create_app(
        model_manager: ModelManager,
        chat_handler: ChatCompletionHandler,
        completion_handler: CompletionHandler,
        response_handler: ResponseHandler,
        transcription_handler: TranscriptionHandler,
        enable_cors: bool = False,
    ):
        service = OpenAI(
            model_manager,
            chat_handler,
            completion_handler,
            response_handler,
            transcription_handler
        )
        app = FastAPI(lifespan=service.lifespan)
        if enable_cors:
            app.add_middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=True,
                allow_methods=["*"],
                allow_headers=["*"],
            )
            service.logger.warning_once(
                "CORS allow origin is set to `*`. Not recommended for production."
            )

        app.add_middleware(
            BaseHTTPMiddleware,
            dispatch=OpenAI.request_id_middleware
        )
        app.include_router(OpenAI.router)
        return app

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        yield
        self.model_manager.shutdown()

    @staticmethod
    async def request_id_middleware(request: Request, call_next):
        """Get or set the request ID in the header."""
        request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers[X_REQUEST_ID] = request_id
        return response

    @cbr.post("/v1/chat/completions")
    async def chat_completions(self, request: Request, body: dict):
        return await self.chat_handler.handle_request(body, request.state.request_id)

    @cbr.post("/v1/completions")
    async def completions(self, request: Request, body: dict):
        return await self.completion_handler.handle_request(body, request.state.request_id)

    @cbr.post("/v1/responses")
    async def responses(self, request: Request, body: dict):
        return await self.response_handler.handle_request(body, request.state.request_id)

    @cbr.post("/v1/audio/transcriptions")
    async def audio_transcriptions(self, request: Request):
        return await self.transcription_handler.handle_request(request)

    @cbr.post("/load_model")
    async def load_model(self, body: dict):
        model = body.get("model")
        if model is None:
            raise HTTPException(
                status_code=422,
                detail="Missing `model` field in the request body."
            )
        model_id_and_revision = self.model_manager.process_model_name(model)
        return StreamingResponse(
            self.model_manager.load_model_streaming(model_id_and_revision),
            media_type="text/event-stream"
        )

    @cbr.post("/reset")
    def reset(self):
        self.model_manager.shutdown()
        return JSONResponse({"status": "ok"})

    @cbr.get("/v1/models")
    @cbr.options("/v1/models")
    def list_models(self):
        return JSONResponse({"object": "list", "data": self.model_manager.get_gen_models()})

    @cbr.get("/health")
    @staticmethod
    def health():
        return JSONResponse({"status": "ok"})
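
The get-or-set logic in `request_id_middleware` can be isolated and checked without a running server. The sketch below is only an illustration: `resolve_request_id` is a hypothetical helper, and the header name `"x-request-id"` is an assumption, since the real value of `X_REQUEST_ID` lives in `.utils` and is not shown in this issue.

```python
import uuid
from typing import Mapping

# Assumed header name; the real constant is imported from `.utils` in the proposal.
X_REQUEST_ID = "x-request-id"


def resolve_request_id(headers: Mapping[str, str]) -> str:
    """Reuse the caller's request ID if present and non-empty, else generate a UUID4."""
    return headers.get(X_REQUEST_ID) or str(uuid.uuid4())


kept = resolve_request_id({X_REQUEST_ID: "abc-123"})  # caller-supplied ID is preserved
generated = resolve_request_id({})                    # missing header: new UUID
replaced = resolve_request_id({X_REQUEST_ID: ""})     # empty header: also a new UUID
```

Because the expression uses `or`, an empty header value is treated the same as a missing one. A plain `dict` is case-sensitive, whereas Starlette's request headers are case-insensitive, so this sketch only mirrors the branching, not the header lookup semantics.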

extent analysis

TL;DR

Refactor the CLI serving code to a class-based architecture to improve maintainability and organization.

Guidance

  • Review the proposed class-based pattern in the OpenAI class to understand how it organizes related endpoints and manages resources.
  • Consider the benefits of explicit architecture, better organization, and standardized patterns in the proposed refactoring.
  • Evaluate how the class-based structure aligns with production teams' structuring of Python services, especially in ML/LLM serving.
  • Assess the potential for future-proofing and adding shared middleware, state management, or additional endpoints with the proposed architecture.
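
The resource-management point can be checked in isolation. The following sketch, using only the standard library, shows that an `@asynccontextmanager` method accessed through an instance behaves as a lifespan-style context manager whose cleanup runs on exit. `FakeModelManager` and `Service` are hypothetical stand-ins, and the `try`/`finally` is a defensive variation that is not part of the proposal.

```python
import asyncio
from contextlib import asynccontextmanager


class FakeModelManager:
    """Hypothetical stand-in that records lifecycle events."""

    def __init__(self) -> None:
        self.events = []

    def shutdown(self) -> None:
        self.events.append("shutdown")


class Service:
    def __init__(self, model_manager: FakeModelManager) -> None:
        self.model_manager = model_manager

    @asynccontextmanager
    async def lifespan(self, app):
        # Code before `yield` would run at startup; code after it runs at shutdown.
        try:
            yield
        finally:
            self.model_manager.shutdown()


async def main() -> list:
    manager = FakeModelManager()
    service = Service(manager)
    # `app=None` stands in for the FastAPI instance the framework would pass.
    async with service.lifespan(app=None):
        manager.events.append("serving")
    return manager.events


events = asyncio.run(main())
```

The bound method `service.lifespan` already carries `self`, so it can be handed to a framework that calls it with the application as its single argument.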

Example

The provided code snippet demonstrates a class-based pattern using the OpenAI class, which includes methods for handling different endpoints, such as chat_completions, completions, and responses.

Notes

The proposed refactoring is not about fixing a bug, but rather about exploring an alternative architectural pattern to benefit the project's long-term maintenance and evolution.

Recommendation

Apply the proposed class-based architecture to improve the organization and maintainability of the CLI serving code, as it provides a clearer foundation for future development and scaling.
