
vllm.entrypoints.pooling.base.serving

PoolingServing

Source code in vllm/entrypoints/pooling/base/serving.py
class PoolingServing:
    request_id_prefix: ClassVar[str]

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
    ):
        super().__init__()
        self.engine_client = engine_client
        self.models = models
        self.model_config = models.model_config
        self.max_model_len = self.model_config.max_model_len
        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.log_error_stack = log_error_stack

        self.io_processor = self.init_io_processor(
            model_config=models.model_config,
            renderer=models.renderer,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            trust_request_chat_template=trust_request_chat_template,
        )

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        *,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
    ) -> PoolingIOProcessor:
        raise NotImplementedError

    async def __call__(
        self,
        request: OpenAIBaseModel,
        raw_request: Request,
    ):
        try:
            model_name = self.models.model_name()
            request_id = (
                f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
            )

            await self._check_model(request, raw_request)

            ctx = ServeContext(
                request=request,
                raw_request=raw_request,
                model_name=model_name,
                request_id=request_id,
            )

            self._validate_request(ctx)
            ctx.lora_request = self._maybe_get_adapters(ctx.request)
            await self._preprocess(ctx)
            await self._prepare_generators(ctx)
            await self._collect_batch(ctx)
            generator = await self._build_response(ctx)
        except Exception as e:
            generator = self.create_error_response(e)

        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.error.code
            )
        elif isinstance(generator, OpenAIBaseModel):
            return JSONResponse(content=generator.model_dump())

        assert_never(generator)

    async def _preprocess(
        self,
        ctx: ServeContext,
    ):
        ctx.engine_prompts = await self.io_processor.pre_process(ctx.request)

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Mapping[str, str] | None:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ):
        """Schedule the request and get the result generator."""
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

        trace_headers = (
            None
            if ctx.raw_request is None
            else await self._get_trace_headers(ctx.raw_request.headers)
        )

        pooling_params = self.io_processor.create_pooling_params(ctx.request)

        for i, engine_prompt in enumerate(ctx.engine_prompts):
            request_id_item = f"{ctx.request_id}-{i}"

            self._log_inputs(
                request_id_item,
                engine_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            generator = self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(generator)

        ctx.result_generator = merge_async_iterators(*generators)

    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect batch results from the result generator."""

        num_prompts = len(ctx.engine_prompts)
        final_res_batch: list[PoolingRequestOutput | None]
        final_res_batch = [None] * num_prompts

        async for i, res in ctx.result_generator:
            final_res_batch[i] = res

        if None in final_res_batch:
            raise ValueError("Failed to generate results for all prompts")

        ctx.final_res_batch = final_res_batch

    async def _build_response(
        self,
        ctx: ServeContext,
    ) -> AnyResponse:
        raise NotImplementedError

    #########################################################
    #########################################################

    @staticmethod
    def _base_request_id(
        raw_request: Request | None, default: str | None = None
    ) -> str | None:
        """Pulls the request id to use from a header, if provided"""
        if raw_request is not None and (
            (req_id := raw_request.headers.get("X-Request-Id")) is not None
        ):
            return req_id

        return random_uuid() if default is None else default

    def _is_model_supported(self, model_name: str | None) -> bool:
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    async def _check_model(
        self,
        request: AnyRequest,
        raw_request: Request,
    ) -> ErrorResponse | None:
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in self.models.lora_requests:
            return None
        if (
            envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
            and request.model
            and (load_result := await self.models.resolve_lora(request.model))
        ):
            if isinstance(load_result, LoRARequest):
                return None
            if (
                isinstance(load_result, ErrorResponse)
                and load_result.error.code == HTTPStatus.BAD_REQUEST.value
            ):
                raise ValueError(load_result.error.message)
        return None

    def _validate_request(self, ctx: ServeContext) -> None:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)

        if (
            truncate_prompt_tokens is not None
            and truncate_prompt_tokens > self.max_model_len
        ):
            raise ValueError(
                "truncate_prompt_tokens value is "
                "greater than max_model_len."
                " Please, select a smaller truncation size."
            )
        return None

    def create_error_response(
        self,
        message: str | Exception,
        err_type: str | None = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        return create_error_response(
            message=message,
            err_type=err_type,
            status_code=status_code,
            param=param,
            log_error_stack=self.log_error_stack,
        )

    def _maybe_get_adapters(
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> LoRARequest | None:
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model]

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora

        if self._is_model_supported(request.model):
            return None

        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # task types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        messages = request.messages
        if messages is None or isinstance(messages, (str, bytes)):
            return message_types

        for message in messages:
            if (
                isinstance(message, dict)
                and "content" in message
                and isinstance(message["content"], list)
            ):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    def _log_inputs(
        self,
        request_id: str,
        inputs: PromptType,
        params: SamplingParams | PoolingParams | BeamSearchParams | None,
        lora_request: LoRARequest | None,
    ) -> None:
        if self.request_logger is None:
            return

        prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
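
The two hooks that raise NotImplementedError, init_io_processor and _build_response, are the intended extension points: a concrete pooling endpoint subclasses PoolingServing, supplies an IO processor that turns API requests into engine prompts and pooling parameters, and converts the collected batch into a response model. A minimal, illustrative sketch of that shape (MyEmbeddingIOProcessor and MyEmbeddingResponse are hypothetical stand-ins, not vLLM classes):

class MyEmbeddingServing(PoolingServing):
    # Used when building request ids of the form "{prefix}-{base id}".
    request_id_prefix = "my-embd"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        *,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        trust_request_chat_template: bool = False,
    ) -> PoolingIOProcessor:
        # The processor must provide pre_process() and create_pooling_params(),
        # which __call__ and _prepare_generators rely on.
        return MyEmbeddingIOProcessor(model_config, renderer)  # hypothetical

    async def _build_response(self, ctx: ServeContext) -> AnyResponse:
        # ctx.final_res_batch holds one PoolingRequestOutput per engine prompt,
        # in submission order.
        return MyEmbeddingResponse.from_outputs(ctx.final_res_batch)  # hypothetical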

_base_request_id staticmethod

_base_request_id(
    raw_request: Request | None, default: str | None = None
) -> str | None

Pulls the request id to use from a header, if provided

Source code in vllm/entrypoints/pooling/base/serving.py
@staticmethod
def _base_request_id(
    raw_request: Request | None, default: str | None = None
) -> str | None:
    """Pulls the request id to use from a header, if provided"""
    if raw_request is not None and (
        (req_id := raw_request.headers.get("X-Request-Id")) is not None
    ):
        return req_id

    return random_uuid() if default is None else default
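
The precedence is: a client-supplied X-Request-Id header wins, then the caller's default, and only then a freshly generated random_uuid(). A standalone sketch of the same rule, using a plain dict in place of the Starlette headers object:

import uuid

def base_request_id(headers: dict[str, str] | None, default: str | None = None) -> str:
    # Mirrors the fallback order: header value, then default, then a new UUID.
    if headers is not None and (req_id := headers.get("X-Request-Id")) is not None:
        return req_id
    return uuid.uuid4().hex if default is None else default

print(base_request_id({"X-Request-Id": "abc-123"}))    # abc-123
print(base_request_id({}, default="fallback-id"))      # fallback-id
print(base_request_id(None))                           # random 32-char hex id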

_collect_batch async

_collect_batch(ctx: ServeContext) -> ErrorResponse | None

Collect batch results from the result generator.

Source code in vllm/entrypoints/pooling/base/serving.py
async def _collect_batch(
    self,
    ctx: ServeContext,
) -> ErrorResponse | None:
    """Collect batch results from the result generator."""

    num_prompts = len(ctx.engine_prompts)
    final_res_batch: list[PoolingRequestOutput | None]
    final_res_batch = [None] * num_prompts

    async for i, res in ctx.result_generator:
        final_res_batch[i] = res

    if None in final_res_batch:
        raise ValueError("Failed to generate results for all prompts")

    ctx.final_res_batch = final_res_batch
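
ctx.result_generator is the merged stream returned by merge_async_iterators, so results arrive as (index, output) pairs in completion order rather than submission order; writing each result into a preallocated slot restores the original prompt order before _build_response runs. A self-contained sketch of that pattern, with toy generators standing in for engine_client.encode and for merge_async_iterators:

import asyncio
from collections.abc import AsyncGenerator

async def fake_encode(i: int) -> AsyncGenerator[str, None]:
    # Stand-in for engine_client.encode: each prompt finishes at a different time.
    await asyncio.sleep(0.01 * (3 - i))
    yield f"result-{i}"

async def merged(gens):
    # Toy stand-in for merge_async_iterators: yields (index, item) pairs in
    # completion order, assuming each generator yields exactly one item.
    async def first(i, gen):
        async for item in gen:
            return i, item

    for fut in asyncio.as_completed([first(i, g) for i, g in enumerate(gens)]):
        yield await fut

async def collect_batch():
    gens = [fake_encode(i) for i in range(3)]
    final: list[str | None] = [None] * len(gens)
    async for i, res in merged(gens):
        final[i] = res  # indexing restores submission order
    assert None not in final
    return final

print(asyncio.run(collect_batch()))  # ['result-0', 'result-1', 'result-2']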

_get_active_default_mm_loras

_get_active_default_mm_loras(
    request: AnyRequest,
) -> LoRARequest | None

Determine if there are any active default multimodal loras.

Source code in vllm/entrypoints/pooling/base/serving.py
def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
    """Determine if there are any active default multimodal loras."""
    # TODO: Currently this is only enabled for chat completions
    # to be better aligned with only being enabled for .generate
    # when run offline. It would be nice to support additional
    # task types in the future.
    message_types = self._get_message_types(request)
    default_mm_loras = set()

    for lora in self.models.lora_requests.values():
        # Best effort match for default multimodal lora adapters;
        # There is probably a better way to do this, but currently
        # this matches against the set of 'types' in any content lists
        # up until '_', e.g., to match audio_url -> audio
        if lora.lora_name in message_types:
            default_mm_loras.add(lora)

    # Currently only support default modality specific loras if
    # we have exactly one lora matched on the request.
    if len(default_mm_loras) == 1:
        return default_mm_loras.pop()
    return None
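
Because matching is done by adapter name against the modality prefixes, the rule only fires when exactly one loaded adapter is implicated by the request's content types; an ambiguous request falls back to no default adapter. A worked example with illustrative adapter names:

# Adapter names registered via lora_requests (illustrative values).
loaded_adapters = {"audio", "image"}

# Request with text and audio_url parts -> prefixes {"text", "audio"}.
matches = loaded_adapters & {"text", "audio"}
print(matches)  # {'audio'} -> exactly one match, the "audio" adapter is used

# Request mixing audio_url and image_url parts -> two matches -> no default.
matches = loaded_adapters & {"text", "audio", "image"}
print(matches)  # {'audio', 'image'} -> ambiguous, None is returned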

_get_message_types

_get_message_types(request: AnyRequest) -> set[str]

Retrieve the set of types from message content dicts up until _; we use this to match potential multimodal data with default per modality loras.

Source code in vllm/entrypoints/pooling/base/serving.py
def _get_message_types(self, request: AnyRequest) -> set[str]:
    """Retrieve the set of types from message content dicts up
    until `_`; we use this to match potential multimodal data
    with default per modality loras.
    """
    message_types: set[str] = set()

    if not hasattr(request, "messages"):
        return message_types

    messages = request.messages
    if messages is None or isinstance(messages, (str, bytes)):
        return message_types

    for message in messages:
        if (
            isinstance(message, dict)
            and "content" in message
            and isinstance(message["content"], list)
        ):
            for content_dict in message["content"]:
                if "type" in content_dict:
                    message_types.add(content_dict["type"].split("_")[0])
    return message_types
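
Each content part's type string is cut at the first underscore, so for example audio_url collapses to audio and image_url to image, which is what lets it line up with per-modality adapter names. A standalone illustration with a hypothetical chat-style payload:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this clip"},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/a.wav"}},
        ],
    },
]

message_types = {
    part["type"].split("_")[0]
    for message in messages
    if isinstance(message, dict) and isinstance(message.get("content"), list)
    for part in message["content"]
    if "type" in part
}
print(message_types)  # {'text', 'audio'}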

_prepare_generators async

_prepare_generators(ctx: ServeContext)

Schedule the request and get the result generator.

Source code in vllm/entrypoints/pooling/base/serving.py
async def _prepare_generators(
    self,
    ctx: ServeContext,
):
    """Schedule the request and get the result generator."""
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

    trace_headers = (
        None
        if ctx.raw_request is None
        else await self._get_trace_headers(ctx.raw_request.headers)
    )

    pooling_params = self.io_processor.create_pooling_params(ctx.request)

    for i, engine_prompt in enumerate(ctx.engine_prompts):
        request_id_item = f"{ctx.request_id}-{i}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        generator = self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

        generators.append(generator)

    ctx.result_generator = merge_async_iterators(*generators)
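
Each engine prompt in the batch gets its own engine_client.encode call, tagged with a per-item id of the form "{request_id}-{i}" so logs and outputs can be traced back to the originating prompt. A small illustration of that id scheme (the prefix and base id values are hypothetical):

request_id = "pool-3f2a9c"                       # "{request_id_prefix}-{base id}"
engine_prompts = ["prompt A", "prompt B", "prompt C"]

item_ids = [f"{request_id}-{i}" for i, _ in enumerate(engine_prompts)]
print(item_ids)  # ['pool-3f2a9c-0', 'pool-3f2a9c-1', 'pool-3f2a9c-2']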