Skip to content

vllm.entrypoints.serve.render.serving

OpenAIServingRender

Source code in vllm/entrypoints/serve/render/serving.py
class OpenAIServingRender:
    """Render-only handler for OpenAI-style chat and completion requests.

    Converts requests into ``ProcessorInputs`` using a ``BaseRenderer`` and
    never talks to an engine client. Most methods are copies of the
    corresponding ``OpenAIServing*`` methods; each docstring says where it
    was copied from and how it differs.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        io_processor: Any,
        served_model_names: list[str],
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        enable_auto_tools: bool = False,
        exclude_tools_when_tool_choice_none: bool = False,
        tool_parser: str | None = None,
        default_chat_template_kwargs: dict[str, Any] | None = None,
        log_error_stack: bool = False,
    ) -> None:
        """Store serving configuration and resolve the tool parser.

        Args:
            model_config: Model configuration. ``model`` is used for tool
                parser lookup and model cards, ``hf_config.model_type``
                selects the Harmony (GPT-OSS) path, and ``max_model_len`` is
                reported by ``show_available_models``.
            renderer: Renderer used to tokenize and render prompts.
            io_processor: Stored as-is; not read by this class's methods.
            served_model_names: Model names accepted in ``request.model``.
            request_logger: Stored as-is; not read by this class's methods.
            chat_template: Server default chat template.
            chat_template_content_format: Server default content format.
            trust_request_chat_template: Allow chat templates supplied in
                requests (see ``_validate_chat_template``).
            enable_auto_tools: Passed to ``ParserManager.get_tool_parser``;
                also required for ``tool_choice="auto"`` when no tool parser
                is available.
            exclude_tools_when_tool_choice_none: Omit tool definitions from
                the prompt when ``tool_choice="none"``.
            tool_parser: Tool parser name passed to ``ParserManager``.
            default_chat_template_kwargs: Server default template kwargs.
            log_error_stack: Stored as-is; not read by this class's methods.
        """
        self.model_config = model_config
        self.renderer = renderer
        self.io_processor = io_processor
        self.served_model_names = served_model_names
        self.request_logger = request_logger
        self.chat_template = chat_template
        self.chat_template_content_format: ChatTemplateContentFormatOption = (
            chat_template_content_format
        )
        self.trust_request_chat_template = trust_request_chat_template
        self.enable_auto_tools = enable_auto_tools
        self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
        self.tool_parser: Callable[[TokenizerLike], ToolParser] | None = (
            ParserManager.get_tool_parser(
                tool_parser_name=tool_parser,
                enable_auto_tools=enable_auto_tools,
                model_name=model_config.model,
            )
        )
        self.default_chat_template_kwargs: dict[str, Any] = (
            default_chat_template_kwargs or {}
        )
        self.log_error_stack = log_error_stack
        # GPT-OSS models are rendered through Harmony instead of a chat template.
        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        # Built-in browsing / code-interpreter tools are not supported;
        # `_make_request_with_harmony` asserts both remain False.
        self.supports_browsing = False
        self.supports_code_interpreter = False

    async def render_chat_request(
        self,
        request: ChatCompletionRequest,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
        """Render a chat completion request into engine prompts.

        Copied from OpenAIServingChat.render_chat_request.

        Differences: engine_client.errored check removed (no engine client).
        Additionally, Harmony requests with ``reasoning_effort="none"`` are
        rejected here with a 400 error response instead of reaching the
        assertion in ``_make_request_with_harmony``.

        Returns:
            ``(conversation, engine_prompts)`` on success, otherwise an
            ``ErrorResponse`` (unknown model, unsupported ``tool_choice``,
            untrusted chat template, or Harmony ``reasoning_effort="none"``).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        tokenizer = self.renderer.tokenizer

        tool_parser = self.tool_parser

        if is_mistral_tokenizer(tokenizer):
            # because of issues with pydantic we need to potentially
            # re-serialize the tool_calls field of the request
            # for more info: see comment in `maybe_serialize_tool_calls`
            _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
            _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
            _mt.validate_request_params(request)

        # Check if tool parsing is unavailable (common condition)
        tool_parsing_unavailable = (
            tool_parser is None
            and not is_mistral_tokenizer(tokenizer)
            and not self.use_harmony
        )

        # Validate tool_choice when tool parsing is required but unavailable
        if tool_parsing_unavailable and request.tool_choice not in (
            None,
            "none",
        ):
            if request.tool_choice == "auto" and not self.enable_auto_tools:
                # for hf tokenizers, "auto" tools requires
                # --enable-auto-tool-choice and --tool-call-parser
                return self.create_error_response(
                    '"auto" tool choice requires '
                    "--enable-auto-tool-choice and --tool-call-parser to be set"
                )
            elif request.tool_choice != "auto":
                # "required" or named tool requires tool parser
                return self.create_error_response(
                    f'tool_choice="{request.tool_choice}" requires '
                    "--tool-call-parser to be set"
                )

        if request.tools is None or (
            request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
        ):
            tool_dicts = None
        else:
            tool_dicts = [tool.model_dump() for tool in request.tools]

        if not self.use_harmony:
            # Common case.
            error_check_ret = self._validate_chat_template(
                request_chat_template=request.chat_template,
                chat_template_kwargs=request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret

            conversation, engine_prompts = await self._preprocess_chat(
                request,
                request.messages,
                default_template=self.chat_template,
                default_template_content_format=self.chat_template_content_format,
                default_template_kwargs=self.default_chat_template_kwargs,
                tool_dicts=tool_dicts,
                tool_parser=tool_parser,
            )
        else:
            # For GPT-OSS.
            # Harmony cannot encode reasoning_effort="none". Reject it as a
            # client error instead of relying on the assert inside
            # `_make_request_with_harmony`, which raises AssertionError (a
            # server error) and is skipped entirely under `python -O`.
            if request.reasoning_effort == "none":
                return self.create_error_response(
                    "Harmony does not support reasoning_effort='none'",
                    param="reasoning_effort",
                )
            should_include_tools = tool_dicts is not None
            conversation, engine_prompts = self._make_request_with_harmony(
                request, should_include_tools
            )

        return conversation, engine_prompts

    async def render_completion_request(
        self,
        request: CompletionRequest,
    ) -> list[ProcessorInputs] | ErrorResponse:
        """Render a completion request into engine prompts.

        Copied from OpenAIServingCompletion.render_completion_request.

        Differences: engine_client.errored check removed (no engine client).

        Returns:
            The rendered prompts, or an ``ErrorResponse`` for an unknown model,
            a ``suffix``, or ``echo`` / ``prompt_logprobs`` combined with
            ``prompt_embeds``.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response("suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response("Echo is unsupported with prompt embeds.")

        if request.prompt_logprobs is not None and request.prompt_embeds is not None:
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds."
            )

        engine_prompts = await self._preprocess_completion(
            request,
            prompt_input=request.prompt,
            prompt_embeds=request.prompt_embeds,
        )

        return engine_prompts

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
        should_include_tools: bool = True,
    ):
        """Build Harmony messages and a single token prompt for ``request``.

        Copied from OpenAIServingChat._make_request_with_harmony.

        Returns:
            ``(messages, [engine_prompt])``: the system / optional developer /
            converted chat messages, and a ``TokensPrompt`` holding the token
            ids produced by ``render_for_completion(messages)``.
        """
        messages: list[OpenAIMessage] = []

        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

        # Add system message.
        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter
        # Internal invariant: `render_chat_request` rejects "none" beforehand.
        assert request.reasoning_effort != "none", (
            "Harmony does not support reasoning_effort='none'"
        )
        sys_msg = get_system_message(
            reasoning_effort=request.reasoning_effort,
            browser_description=None,
            python_description=None,
            with_custom_tools=should_include_tools,
        )
        messages.append(sys_msg)

        # Add developer message (tool definitions are omitted when
        # should_include_tools is False, but the message is still added).
        if request.tools:
            dev_msg = get_developer_message(
                tools=request.tools if should_include_tools else None  # type: ignore[arg-type]
            )
            messages.append(dev_msg)

        # Add user message.
        messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)
        engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

        # Add cache_salt if provided in the request
        if request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return messages, [engine_prompt]

    async def show_available_models(self) -> ModelList:
        """Returns the models served by this render server.

        One ``ModelCard`` is produced per served name, all sharing the same
        ``max_model_len`` and ``root`` (``model_config.model``).
        """
        max_model_len = self.model_config.max_model_len
        return ModelList(
            data=[
                ModelCard(
                    id=name,
                    max_model_len=max_model_len,
                    root=self.model_config.model,
                    permission=[ModelPermission()],
                )
                for name in self.served_model_names
            ]
        )

    def create_error_response(
        self,
        message: str | Exception,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
        param: str | None = None,
    ) -> ErrorResponse:
        """Build an ``ErrorResponse`` (defaults to a 400 BadRequestError).

        Delegates to the module-level ``create_error_response`` helper.
        """
        return create_error_response(message, err_type, status_code, param)

    def _is_model_supported(self, model_name: str) -> bool:
        """Simplified from OpenAIServing._is_model_supported (no LoRA support).

        Returns True iff ``model_name`` is one of ``served_model_names``.
        """
        return model_name in self.served_model_names

    async def _check_model(
        self,
        request: Any,
    ) -> ErrorResponse | None:
        """Simplified from OpenAIServing._check_model (no LoRA support).

        Returns None when ``request.model`` is served, otherwise a 404
        ``NotFoundError`` response pointing at the ``model`` parameter.
        """
        if self._is_model_supported(request.model):
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

    def _validate_chat_template(
        self,
        request_chat_template: str | None,
        chat_template_kwargs: dict[str, Any] | None,
        trust_request_chat_template: bool,
    ) -> ErrorResponse | None:
        """Copied from OpenAIServing._validate_chat_template.

        Returns an error response when the request supplies a chat template
        (directly or via ``chat_template_kwargs["chat_template"]``) and
        request templates are not trusted; otherwise None.
        """
        if not trust_request_chat_template and (
            request_chat_template is not None
            or (
                chat_template_kwargs
                and chat_template_kwargs.get("chat_template") is not None
            )
        ):
            return self.create_error_response(
                "Chat template is passed with request, but "
                "--trust-request-chat-template is not set. "
                "Refused request with untrusted chat template."
            )
        return None

    async def _preprocess_completion(
        self,
        request: Any,
        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
        prompt_embeds: bytes | list[bytes] | None,
    ) -> list[ProcessorInputs]:
        """Copied from OpenAIServing._preprocess_completion.

        Concatenates the embeds prompts (first) and the text/token prompts
        into one list and renders it via ``_preprocess_cmpl``.
        """
        prompts = list[SingletonPrompt | bytes]()
        if prompt_embeds is not None:  # embeds take higher priority
            prompts.extend(prompt_to_seq(prompt_embeds))
        if prompt_input is not None:
            prompts.extend(prompt_to_seq(prompt_input))
        return await self._preprocess_cmpl(request, prompts)

    async def _preprocess_cmpl(
        self,
        request: Any,
        prompts: Sequence[PromptType | bytes],
    ) -> list[ProcessorInputs]:
        """Copied from OpenAIServing._preprocess_cmpl.

        ``bytes`` entries are passed through unparsed; every other prompt goes
        through ``parse_model_prompt``. ``mm_processor_kwargs`` and
        ``cache_salt`` are forwarded as prompt extras when set on the request.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            (
                prompt
                if isinstance(prompt, bytes)
                else parse_model_prompt(model_config, prompt)
            )
            for prompt in prompts
        ]
        tok_params = request.build_tok_params(model_config)

        return await renderer.render_cmpl_async(
            parsed_prompts,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

    async def _preprocess_chat(
        self,
        request: Any,
        messages: list[Any],
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
        default_template_kwargs: dict[str, Any] | None,
        tool_dicts: list[dict[str, Any]] | None = None,
        tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
    ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
        """Copied from OpenAIServing._preprocess_chat.

        Differences: isinstance check is ChatCompletionRequest-only
        (ResponsesRequest not supported here); TODO comment dropped accordingly.

        Renders ``messages`` as a single conversation and returns
        ``(conversation, [engine_prompt])``.

        Raises:
            NotImplementedError: If a tool parser is set, ``tool_choice`` is
                not ``"none"``, and ``request`` is not a ChatCompletionRequest.
        """
        renderer = self.renderer

        # Per-call values override the server defaults; Mistral tokenizers
        # render with tokenize=True.
        default_template_kwargs = merge_kwargs(
            default_template_kwargs,
            dict(
                tools=tool_dicts,
                tokenize=is_mistral_tokenizer(renderer.tokenizer),
            ),
        )

        tok_params = request.build_tok_params(self.model_config)
        chat_params = request.build_chat_params(
            default_template, default_template_content_format
        ).with_defaults(default_template_kwargs)

        # Render a batch of exactly one conversation and unpack it.
        (conversation,), (engine_prompt,) = await renderer.render_chat_async(
            [messages],
            chat_params,
            tok_params,
            prompt_extras={
                k: v
                for k in ("mm_processor_kwargs", "cache_salt")
                if (v := getattr(request, k, None)) is not None
            },
        )

        # tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the LLM)
        if tool_parser is not None:
            tool_choice = getattr(request, "tool_choice", "none")
            if tool_choice != "none":
                if not isinstance(request, ChatCompletionRequest):
                    msg = (
                        "Tool usage is only supported "
                        "for ChatCompletionRequest, but got "
                        f"{type(request).__name__}"
                    )
                    raise NotImplementedError(msg)
                tokenizer = renderer.get_tokenizer()
                # The returned request is not used further here; only changes
                # adjust_request makes to `request` in place (if any) persist.
                # NOTE(review): confirm adjust_request mutates in place.
                tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

        return conversation, [engine_prompt]

_check_model async

_check_model(request: Any) -> ErrorResponse | None

Simplified from OpenAIServing._check_model (no LoRA support).

Source code in vllm/entrypoints/serve/render/serving.py
async def _check_model(
    self,
    request: Any,
) -> ErrorResponse | None:
    """Return a 404 error response unless ``request.model`` is served.

    Simplified from OpenAIServing._check_model (no LoRA support).
    """
    if not self._is_model_supported(request.model):
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )
    # Known model: nothing to report.
    return None

_is_model_supported

_is_model_supported(model_name: str) -> bool

Simplified from OpenAIServing._is_model_supported (no LoRA support).

Source code in vllm/entrypoints/serve/render/serving.py
def _is_model_supported(self, model_name: str) -> bool:
    """Simplified from OpenAIServing._is_model_supported (no LoRA support)."""
    return model_name in self.served_model_names

_make_request_with_harmony

_make_request_with_harmony(
    request: ChatCompletionRequest,
    should_include_tools: bool = True,
)

Copied from OpenAIServingChat._make_request_with_harmony.

Source code in vllm/entrypoints/serve/render/serving.py
def _make_request_with_harmony(
    self,
    request: ChatCompletionRequest,
    should_include_tools: bool = True,
):
    """Build the Harmony message list and one token prompt for ``request``.

    Copied from OpenAIServingChat._make_request_with_harmony.

    Returns:
        ``(messages, [engine_prompt])`` where ``engine_prompt`` is a
        ``TokensPrompt`` over ``render_for_completion(messages)``.
    """
    # because of issues with pydantic we need to potentially
    # re-serialize the tool_calls field of the request
    # for more info: see comment in `maybe_serialize_tool_calls`
    _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

    # NOTE: In Chat Completion API, browsing is enabled by default
    # if the model supports it. TODO: Support browsing.
    assert not self.supports_browsing
    assert not self.supports_code_interpreter
    assert request.reasoning_effort != "none", (
        "Harmony does not support reasoning_effort='none'"
    )

    # System message first.
    messages: list[OpenAIMessage] = [
        get_system_message(
            reasoning_effort=request.reasoning_effort,
            browser_description=None,
            python_description=None,
            with_custom_tools=should_include_tools,
        )
    ]

    # Developer message, only when the request declares tools.
    if request.tools:
        dev_tools = request.tools if should_include_tools else None
        messages.append(get_developer_message(tools=dev_tools))  # type: ignore[arg-type]

    # Then the converted chat history.
    messages += parse_chat_inputs_to_harmony_messages(request.messages)

    engine_prompt = TokensPrompt(prompt_token_ids=render_for_completion(messages))

    salt = request.cache_salt
    if salt is not None:
        engine_prompt["cache_salt"] = salt

    return messages, [engine_prompt]

_preprocess_chat async

_preprocess_chat(
    request: Any,
    messages: list[Any],
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
    default_template_kwargs: dict[str, Any] | None,
    tool_dicts: list[dict[str, Any]] | None = None,
    tool_parser: Callable[[TokenizerLike], ToolParser]
    | None = None,
) -> tuple[
    list[ConversationMessage], list[ProcessorInputs]
]

Copied from OpenAIServing._preprocess_chat.

Differences: isinstance check is ChatCompletionRequest-only (ResponsesRequest not supported here); TODO comment dropped accordingly.

Source code in vllm/entrypoints/serve/render/serving.py
async def _preprocess_chat(
    self,
    request: Any,
    messages: list[Any],
    default_template: str | None,
    default_template_content_format: ChatTemplateContentFormatOption,
    default_template_kwargs: dict[str, Any] | None,
    tool_dicts: list[dict[str, Any]] | None = None,
    tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
    """Copied from OpenAIServing._preprocess_chat.

    Differences: isinstance check is ChatCompletionRequest-only
    (ResponsesRequest not supported here); TODO comment dropped accordingly.

    Renders ``messages`` as a single conversation and returns
    ``(conversation, [engine_prompt])``.

    Raises:
        NotImplementedError: If a tool parser is set, ``tool_choice`` is not
            ``"none"``, and ``request`` is not a ChatCompletionRequest.
    """
    renderer = self.renderer

    # Per-call values override the server defaults; Mistral tokenizers
    # render with tokenize=True.
    default_template_kwargs = merge_kwargs(
        default_template_kwargs,
        dict(
            tools=tool_dicts,
            tokenize=is_mistral_tokenizer(renderer.tokenizer),
        ),
    )

    tok_params = request.build_tok_params(self.model_config)
    chat_params = request.build_chat_params(
        default_template, default_template_content_format
    ).with_defaults(default_template_kwargs)

    # Render a batch of exactly one conversation and unpack it.
    (conversation,), (engine_prompt,) = await renderer.render_chat_async(
        [messages],
        chat_params,
        tok_params,
        prompt_extras={
            k: v
            for k in ("mm_processor_kwargs", "cache_salt")
            if (v := getattr(request, k, None)) is not None
        },
    )

    # tool parsing is done only if a tool_parser has been set and if
    # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
    # is set, we want to prevent parsing a tool_call hallucinated by the LLM)
    if tool_parser is not None:
        tool_choice = getattr(request, "tool_choice", "none")
        if tool_choice != "none":
            if not isinstance(request, ChatCompletionRequest):
                msg = (
                    "Tool usage is only supported "
                    "for ChatCompletionRequest, but got "
                    f"{type(request).__name__}"
                )
                raise NotImplementedError(msg)
            tokenizer = renderer.get_tokenizer()
            # The returned request is not used further here; only changes
            # adjust_request makes to `request` in place (if any) persist.
            # NOTE(review): confirm adjust_request mutates in place.
            tool_parser(tokenizer).adjust_request(request=request)  # type: ignore[arg-type]

    return conversation, [engine_prompt]

_preprocess_cmpl async

_preprocess_cmpl(
    request: Any, prompts: Sequence[PromptType | bytes]
) -> list[ProcessorInputs]

Copied from OpenAIServing._preprocess_cmpl.

Source code in vllm/entrypoints/serve/render/serving.py
async def _preprocess_cmpl(
    self,
    request: Any,
    prompts: Sequence[PromptType | bytes],
) -> list[ProcessorInputs]:
    """Parse each prompt and render the batch for a completion request.

    Copied from OpenAIServing._preprocess_cmpl.

    ``bytes`` entries are passed through unparsed; all other prompts go
    through ``parse_model_prompt``. ``mm_processor_kwargs`` and
    ``cache_salt`` are forwarded as prompt extras when set on the request.
    """
    renderer = self.renderer
    cfg = self.model_config

    parsed = []
    for item in prompts:
        if isinstance(item, bytes):
            parsed.append(item)
        else:
            parsed.append(parse_model_prompt(cfg, item))

    tok_params = request.build_tok_params(cfg)

    extras = {}
    for key in ("mm_processor_kwargs", "cache_salt"):
        value = getattr(request, key, None)
        if value is not None:
            extras[key] = value

    return await renderer.render_cmpl_async(
        parsed,
        tok_params,
        prompt_extras=extras,
    )

_preprocess_completion async

_preprocess_completion(
    request: Any,
    prompt_input: str
    | list[str]
    | list[int]
    | list[list[int]]
    | None,
    prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]

Copied from OpenAIServing._preprocess_completion.

Source code in vllm/entrypoints/serve/render/serving.py
async def _preprocess_completion(
    self,
    request: Any,
    prompt_input: str | list[str] | list[int] | list[list[int]] | None,
    prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]:
    """Flatten embeds and text/token prompts into one list, then render it.

    Copied from OpenAIServing._preprocess_completion.
    """
    collected: list[SingletonPrompt | bytes] = []
    # Embeds take higher priority, so they are placed first.
    if prompt_embeds is not None:
        collected += prompt_to_seq(prompt_embeds)
    if prompt_input is not None:
        collected += prompt_to_seq(prompt_input)
    return await self._preprocess_cmpl(request, collected)

_validate_chat_template

_validate_chat_template(
    request_chat_template: str | None,
    chat_template_kwargs: dict[str, Any] | None,
    trust_request_chat_template: bool,
) -> ErrorResponse | None

Copied from OpenAIServing._validate_chat_template.

Source code in vllm/entrypoints/serve/render/serving.py
def _validate_chat_template(
    self,
    request_chat_template: str | None,
    chat_template_kwargs: dict[str, Any] | None,
    trust_request_chat_template: bool,
) -> ErrorResponse | None:
    """Refuse request-supplied chat templates unless they are trusted.

    Copied from OpenAIServing._validate_chat_template.

    A template counts as supplied when ``request_chat_template`` is set or
    ``chat_template_kwargs["chat_template"]`` is not None.
    """
    if trust_request_chat_template:
        return None
    if request_chat_template is None:
        if (
            not chat_template_kwargs
            or chat_template_kwargs.get("chat_template") is None
        ):
            return None
    return self.create_error_response(
        "Chat template is passed with request, but "
        "--trust-request-chat-template is not set. "
        "Refused request with untrusted chat template."
    )

render_chat_request async

render_chat_request(
    request: ChatCompletionRequest,
) -> (
    tuple[list[ConversationMessage], list[ProcessorInputs]]
    | ErrorResponse
)

Copied from OpenAIServingChat.render_chat_request.

Differences: engine_client.errored check removed (no engine client).

Source code in vllm/entrypoints/serve/render/serving.py
async def render_chat_request(
    self,
    request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
    """Render a chat completion request into engine prompts.

    Copied from OpenAIServingChat.render_chat_request.

    Differences: engine_client.errored check removed (no engine client).
    Additionally, Harmony requests with ``reasoning_effort="none"`` are
    rejected here with a 400 error response instead of reaching the
    assertion in ``_make_request_with_harmony``.

    Returns:
        ``(conversation, engine_prompts)`` on success, otherwise an
        ``ErrorResponse`` (unknown model, unsupported ``tool_choice``,
        untrusted chat template, or Harmony ``reasoning_effort="none"``).
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        logger.error("Error with model %s", error_check_ret)
        return error_check_ret

    tokenizer = self.renderer.tokenizer

    tool_parser = self.tool_parser

    if is_mistral_tokenizer(tokenizer):
        # because of issues with pydantic we need to potentially
        # re-serialize the tool_calls field of the request
        # for more info: see comment in `maybe_serialize_tool_calls`
        _mt.maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
        _mt.truncate_tool_call_ids(request)  # type: ignore[arg-type]
        _mt.validate_request_params(request)

    # Check if tool parsing is unavailable (common condition)
    tool_parsing_unavailable = (
        tool_parser is None
        and not is_mistral_tokenizer(tokenizer)
        and not self.use_harmony
    )

    # Validate tool_choice when tool parsing is required but unavailable
    if tool_parsing_unavailable and request.tool_choice not in (
        None,
        "none",
    ):
        if request.tool_choice == "auto" and not self.enable_auto_tools:
            # for hf tokenizers, "auto" tools requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
                '"auto" tool choice requires '
                "--enable-auto-tool-choice and --tool-call-parser to be set"
            )
        elif request.tool_choice != "auto":
            # "required" or named tool requires tool parser
            return self.create_error_response(
                f'tool_choice="{request.tool_choice}" requires '
                "--tool-call-parser to be set"
            )

    if request.tools is None or (
        request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
    ):
        tool_dicts = None
    else:
        tool_dicts = [tool.model_dump() for tool in request.tools]

    if not self.use_harmony:
        # Common case.
        error_check_ret = self._validate_chat_template(
            request_chat_template=request.chat_template,
            chat_template_kwargs=request.chat_template_kwargs,
            trust_request_chat_template=self.trust_request_chat_template,
        )
        if error_check_ret is not None:
            return error_check_ret

        conversation, engine_prompts = await self._preprocess_chat(
            request,
            request.messages,
            default_template=self.chat_template,
            default_template_content_format=self.chat_template_content_format,
            default_template_kwargs=self.default_chat_template_kwargs,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
    else:
        # For GPT-OSS.
        # Harmony cannot encode reasoning_effort="none". Reject it as a client
        # error instead of relying on the assert inside
        # `_make_request_with_harmony`, which raises AssertionError (a server
        # error) and is skipped entirely under `python -O`.
        if request.reasoning_effort == "none":
            return self.create_error_response(
                "Harmony does not support reasoning_effort='none'",
                param="reasoning_effort",
            )
        should_include_tools = tool_dicts is not None
        conversation, engine_prompts = self._make_request_with_harmony(
            request, should_include_tools
        )

    return conversation, engine_prompts

render_completion_request async

render_completion_request(
    request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse

Copied from OpenAIServingCompletion.render_completion_request.

Differences: engine_client.errored check removed (no engine client).

Source code in vllm/entrypoints/serve/render/serving.py
async def render_completion_request(
    self,
    request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse:
    """Validate a completion request and render it into engine prompts.

    Copied from OpenAIServingCompletion.render_completion_request.

    Differences: engine_client.errored check removed (no engine client).
    """
    if (model_err := await self._check_model(request)) is not None:
        return model_err

    # Reject feature combinations this server does not support.
    if request.suffix is not None:
        return self.create_error_response("suffix is not currently supported")

    if request.echo and request.prompt_embeds is not None:
        return self.create_error_response("Echo is unsupported with prompt embeds.")

    if request.prompt_logprobs is not None and request.prompt_embeds is not None:
        return self.create_error_response(
            "prompt_logprobs is not compatible with prompt embeds."
        )

    return await self._preprocess_completion(
        request,
        prompt_input=request.prompt,
        prompt_embeds=request.prompt_embeds,
    )

show_available_models async

show_available_models() -> ModelList

Returns the models served by this render server.

Source code in vllm/entrypoints/serve/render/serving.py
async def show_available_models(self) -> ModelList:
    """Returns the models served by this render server.

    Every served name gets a ``ModelCard`` sharing the configured
    ``max_model_len`` and ``root`` (``model_config.model``).
    """
    cfg = self.model_config
    context_len = cfg.max_model_len
    cards = []
    for served_name in self.served_model_names:
        cards.append(
            ModelCard(
                id=served_name,
                max_model_len=context_len,
                root=cfg.model,
                permission=[ModelPermission()],
            )
        )
    return ModelList(data=cards)