Skip to content

vllm.v1.worker.gpu.cudagraph_utils

BatchExecutionDescriptor dataclass

Describes the batch shape and the CUDA graph (CG) mode to run; it is used to match shapes between capture time and runtime.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
@dataclass(frozen=True)
class BatchExecutionDescriptor:
    """Describes the batch shape and the CUDA graph (CG) mode to run; used to
    match shapes between capture time and runtime.

    Frozen so instances are hashable and can be used as dict keys for the
    captured-graph lookup.
    """

    # CUDA graph mode to run this batch with (e.g. NONE / PIECEWISE / FULL).
    cg_mode: CUDAGraphMode
    # Padded number of tokens in the batch.
    num_tokens: int
    num_reqs: int | None  # None means no request padding is needed (PIECEWISE graphs)
    # Tokens per request when the batch is uniform (e.g. pure decode);
    # None for non-uniform / mixed batches.
    uniform_token_count: int | None = None

CudaGraphManager

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
class CudaGraphManager:
    """Captures CUDA graphs for a set of padded batch shapes and replays them.

    Maintains a mapping from BatchExecutionDescriptor to captured CUDAGraph
    objects, plus per-token-count candidate descriptor lists so ``dispatch``
    can pick a compatible captured shape (or fall back to eager execution)
    at runtime.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        """Set up capture bookkeeping and precompute candidate descriptors.

        Args:
            vllm_config: Global engine configuration.
            device: CUDA device the graphs are captured on.
            cudagraph_mode: Overall CG mode; a falsy value (NONE) disables
                capture entirely.
            decode_query_len: Tokens per request in a uniform decode step.
        """
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = cudagraph_mode
        self.decode_query_len = decode_query_len
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Captured FULL graphs, keyed by the exact capture-time descriptor.
        self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
        # Shared global pool so graphs can reuse each other's memory;
        # skipped when CUDA graphs are disabled.
        self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None

        self._graphs_captured = False
        # _candidates[num_tokens] -> descriptors to try at dispatch time,
        # in priority order.
        self._candidates: list[list[BatchExecutionDescriptor]] = []
        # Descriptors to capture, grouped by CG mode.
        self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
        self._init_candidates()

    def _init_candidates(self) -> None:
        """Build priority-ordered candidate lists for each token count."""
        capture_sizes = self.compilation_config.cudagraph_capture_sizes
        if not (self.cudagraph_mode and capture_sizes):
            # CUDA graphs disabled or no capture sizes configured.
            return

        capture_sizes = sorted(capture_sizes)
        max_decode_tokens = self.max_num_reqs * self.decode_query_len
        decode_mode = self.cudagraph_mode.decode_mode()
        mixed_mode = self.cudagraph_mode.mixed_mode()
        separate_decode_routine = self.cudagraph_mode.separate_routine()

        descs_by_token_count = defaultdict(list)
        descs_by_mode = defaultdict(list)

        for num_tokens in capture_sizes:
            # Capture uniform decode specific graphs if required
            #  (i.e. separate decode routine). Appended before the mixed
            #  descriptor of the same size, so dispatch prefers the
            #  uniform-decode graph when both are compatible.
            if (
                separate_decode_routine
                and decode_mode
                and self.decode_query_len <= num_tokens <= max_decode_tokens
            ):
                desc = BatchExecutionDescriptor(
                    cg_mode=decode_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_tokens // self.decode_query_len,
                    uniform_token_count=self.decode_query_len,
                )
                descs_by_mode[decode_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

            if mixed_mode:
                # for PIECEWISE graphs there is no limit on requests when replaying
                # i.e. no request padding is needed
                # so we leave it as None
                num_reqs = (
                    min(num_tokens, self.max_num_reqs)
                    if mixed_mode == CUDAGraphMode.FULL
                    else None
                )
                desc = BatchExecutionDescriptor(
                    cg_mode=mixed_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_reqs,
                )
                descs_by_mode[mixed_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

        if not descs_by_token_count:
            return

        sorted_padded = sorted(descs_by_token_count.keys())
        self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]

        # Map every token count to the candidate list of the smallest
        # capture size that can hold it (slots intentionally share the
        # same list objects).
        current_range_start = 0
        for cg_size in sorted_padded:
            for i in range(current_range_start, cg_size + 1):
                self._candidates[i] = descs_by_token_count[cg_size]
            current_range_start = cg_size + 1

        # Largest batch first; presumably so the biggest activation buffers
        # are allocated first in the shared pool — TODO confirm.
        for mode, descs in descs_by_mode.items():
            descs.sort(key=lambda d: d.num_tokens, reverse=True)
            self._capture_descs[mode] = descs

    def needs_capture(self) -> bool:
        """Return True if any CUDA graph capture is planned."""
        return len(self._capture_descs) > 0

    @torch.inference_mode()
    def capture(
        self,
        create_forward_fn: Callable[
            [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
        ],
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs.

        Args:
            create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
                returns a function that runs forward with a given CUDAGraphMode.
            progress_bar_desc: Label for the progress bar (global rank 0 only).
        """
        with graph_capture(device=self.device):
            # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
            # activations so FULL activations should fit in already allocated
            # buffers in the graph pool.
            for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
                if mode not in self._capture_descs:
                    continue

                descs = self._capture_descs[mode]
                # Only the global first rank shows a progress bar.
                if is_global_first_rank():
                    descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
                for desc in descs:
                    # Prepare inputs and get forward function
                    forward_fn = create_forward_fn(desc)

                    # Warmup: one eager run before capturing, so any lazy
                    # initialization happens outside the graph.
                    forward_fn(CUDAGraphMode.NONE)

                    # Capture
                    logger.debug(
                        "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
                    )
                    if desc.cg_mode == CUDAGraphMode.PIECEWISE:
                        # Piecewise capture happens inside the compiled model;
                        # no explicit CUDAGraph object is recorded here.
                        forward_fn(CUDAGraphMode.PIECEWISE)
                    else:
                        assert desc not in self.graphs, (
                            f"Graph already captured for {desc}"
                        )
                        graph = torch.cuda.CUDAGraph()
                        # Sync offloader's copy stream before capture.
                        # Ensure any pre-capture prefetches from offloader are complete.
                        get_offloader().sync_prev_onload()
                        with torch.cuda.graph(graph, self.pool):
                            # Run forward in NONE mode so the full pass is
                            # recorded into `graph`.
                            forward_fn(CUDAGraphMode.NONE)
                            # Join offloader's copy stream after forward to avoid
                            # unjoined stream error. The last layer's start_prefetch
                            # forks copy_stream, but wait_prefetch only happens in
                            # the next forward pass.
                            get_offloader().join_after_forward()
                        self.graphs[desc] = graph
        self._graphs_captured = True

    def dispatch(
        self,
        num_reqs: int,
        num_tokens: int,
        uniform_token_count: int | None,
    ) -> BatchExecutionDescriptor:
        """Find matching cudagraph descriptor from priority-ordered candidates."""
        if self._graphs_captured and 0 < num_tokens < len(self._candidates):
            for desc in self._candidates[num_tokens]:
                if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
                    return desc
        # No compatible captured graph: run eagerly with the unpadded shape.
        return BatchExecutionDescriptor(
            cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
        )

    def run_fullgraph(self, desc: BatchExecutionDescriptor):
        """Replay a captured FULL cudagraph.

        Args:
            desc: Descriptor the graph was captured with; must be FULL mode
                and present in ``self.graphs``.
        """
        assert desc.cg_mode == CUDAGraphMode.FULL, (
            f"Expected FULL mode, got {desc.cg_mode}"
        )
        assert desc in self.graphs, f"No cudagraph for {desc}"
        # Sync offloader before replay - needed when transitioning from
        # eager/piecewise to full cudagraph (e.g., prefill → decode).
        # The previous eager iteration's start_prefetch may have queued
        # H2D copies on copy_stream that the graph's captured events
        # cannot see. Without this, replay could overwrite static buffers
        # while those copies are still in flight.
        get_offloader().sync_prev_onload()
        self.graphs[desc].replay()

_init_candidates

_init_candidates() -> None

Build priority-ordered candidate lists for each token count.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def _init_candidates(self) -> None:
    """Build priority-ordered candidate lists for each token count."""
    capture_sizes = self.compilation_config.cudagraph_capture_sizes
    if not (self.cudagraph_mode and capture_sizes):
        # CUDA graphs disabled or no capture sizes configured.
        return

    capture_sizes = sorted(capture_sizes)
    max_decode_tokens = self.max_num_reqs * self.decode_query_len
    decode_mode = self.cudagraph_mode.decode_mode()
    mixed_mode = self.cudagraph_mode.mixed_mode()
    separate_decode_routine = self.cudagraph_mode.separate_routine()

    descs_by_token_count = defaultdict(list)
    descs_by_mode = defaultdict(list)

    for num_tokens in capture_sizes:
        # Capture uniform decode specific graphs if required
        #  (i.e. separate decode routine). Appended first so dispatch
        #  prefers them over the same-size mixed graph.
        if (
            separate_decode_routine
            and decode_mode
            and self.decode_query_len <= num_tokens <= max_decode_tokens
        ):
            desc = BatchExecutionDescriptor(
                cg_mode=decode_mode,
                num_tokens=num_tokens,
                num_reqs=num_tokens // self.decode_query_len,
                uniform_token_count=self.decode_query_len,
            )
            descs_by_mode[decode_mode].append(desc)
            descs_by_token_count[num_tokens].append(desc)

        if mixed_mode:
            # for PIECEWISE graphs there is no limit on requests when replaying
            # i.e. no request padding is needed
            # so we leave it as None
            num_reqs = (
                min(num_tokens, self.max_num_reqs)
                if mixed_mode == CUDAGraphMode.FULL
                else None
            )
            desc = BatchExecutionDescriptor(
                cg_mode=mixed_mode,
                num_tokens=num_tokens,
                num_reqs=num_reqs,
            )
            descs_by_mode[mixed_mode].append(desc)
            descs_by_token_count[num_tokens].append(desc)

    if not descs_by_token_count:
        return

    sorted_padded = sorted(descs_by_token_count.keys())
    self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]

    # Map every token count to the candidate list of the smallest capture
    # size that can hold it.
    current_range_start = 0
    for cg_size in sorted_padded:
        for i in range(current_range_start, cg_size + 1):
            self._candidates[i] = descs_by_token_count[cg_size]
        current_range_start = cg_size + 1

    # Largest batch first; presumably so the biggest buffers are allocated
    # first in the shared pool — TODO confirm.
    for mode, descs in descs_by_mode.items():
        descs.sort(key=lambda d: d.num_tokens, reverse=True)
        self._capture_descs[mode] = descs

capture

capture(
    create_forward_fn: Callable[
        [BatchExecutionDescriptor],
        Callable[[CUDAGraphMode], None],
    ],
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None

Capture CUDA graphs.

Parameters:

Name Type Description Default
create_forward_fn Callable[[BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]]

Factory that prepares inputs (OUTSIDE graph) and returns a function that runs forward with a given CUDAGraphMode.

required
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
@torch.inference_mode()
def capture(
    self,
    create_forward_fn: Callable[
        [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
    ],
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
    """Capture CUDA graphs.

    Args:
        create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
            returns a function that runs forward with a given CUDAGraphMode.
        progress_bar_desc: Label for the progress bar (global rank 0 only).
    """
    with graph_capture(device=self.device):
        # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
        # activations so FULL activations should fit in already allocated
        # buffers in the graph pool.
        for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
            if mode not in self._capture_descs:
                continue

            descs = self._capture_descs[mode]
            # Only the global first rank shows a progress bar.
            if is_global_first_rank():
                descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
            for desc in descs:
                # Prepare inputs and get forward function
                forward_fn = create_forward_fn(desc)

                # Warmup: one eager run so lazy initialization happens
                # outside the graph.
                forward_fn(CUDAGraphMode.NONE)

                # Capture
                logger.debug(
                    "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
                )
                if desc.cg_mode == CUDAGraphMode.PIECEWISE:
                    # Piecewise capture happens inside the compiled model;
                    # no explicit CUDAGraph object is recorded here.
                    forward_fn(CUDAGraphMode.PIECEWISE)
                else:
                    assert desc not in self.graphs, (
                        f"Graph already captured for {desc}"
                    )
                    graph = torch.cuda.CUDAGraph()
                    # Sync offloader's copy stream before capture.
                    # Ensure any pre-capture prefetches from offloader are complete.
                    get_offloader().sync_prev_onload()
                    with torch.cuda.graph(graph, self.pool):
                        forward_fn(CUDAGraphMode.NONE)
                        # Join offloader's copy stream after forward to avoid
                        # unjoined stream error. The last layer's start_prefetch
                        # forks copy_stream, but wait_prefetch only happens in
                        # the next forward pass.
                        get_offloader().join_after_forward()
                    self.graphs[desc] = graph
    self._graphs_captured = True

dispatch

dispatch(
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> BatchExecutionDescriptor

Find matching cudagraph descriptor from priority-ordered candidates.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def dispatch(
    self,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> BatchExecutionDescriptor:
    """Find matching cudagraph descriptor from priority-ordered candidates.

    Falls back to a CUDAGraphMode.NONE (eager) descriptor when no captured
    graph is compatible with the given batch shape.
    """
    if self._graphs_captured and 0 < num_tokens < len(self._candidates):
        for desc in self._candidates[num_tokens]:
            if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
                return desc
    # No compatible captured graph: run eagerly with the unpadded shape.
    return BatchExecutionDescriptor(
        cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
    )

run_fullgraph

run_fullgraph(desc: BatchExecutionDescriptor)

Replay a captured FULL cudagraph.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def run_fullgraph(self, desc: BatchExecutionDescriptor):
    """Replay a captured FULL cudagraph.

    Args:
        desc: Descriptor the graph was captured with; must be FULL mode
            and present in ``self.graphs``.
    """
    assert desc.cg_mode == CUDAGraphMode.FULL, (
        f"Expected FULL mode, got {desc.cg_mode}"
    )
    assert desc in self.graphs, f"No cudagraph for {desc}"
    # Sync offloader before replay - needed when transitioning from
    # eager/piecewise to full cudagraph (e.g., prefill → decode).
    # The previous eager iteration's start_prefetch may have queued
    # H2D copies on copy_stream that the graph's captured events
    # cannot see. Without this, replay could overwrite static buffers
    # while those copies are still in flight.
    get_offloader().sync_prev_onload()
    self.graphs[desc].replay()

ModelCudaGraphManager

Bases: CudaGraphManager

CudaGraphManager with model-specific capture and hidden state management.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
class ModelCudaGraphManager(CudaGraphManager):
    """CudaGraphManager with model-specific capture and hidden state management."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
        # Static output buffer shared by all captured graphs; allocated
        # lazily (torch.empty_like) during the first capture forward.
        self.hidden_states: torch.Tensor | None = None
        # Static buffers for auxiliary hidden-state outputs; only populated
        # when use_aux_hidden_state_outputs is True.
        self.aux_hidden_states: list[torch.Tensor] = []
        self.use_aux_hidden_state_outputs = False

    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        use_aux_hidden_state_outputs: bool = False,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for model forward pass.

        Args:
            model: Model whose forward is captured.
            model_state: Provides dummy inputs for capture.
            input_buffers: Static input buffers sliced per batch size.
            block_tables: KV-cache block tables used to build dummy inputs.
            attn_groups: Attention groups used to build dummy metadata.
            kv_cache_config: KV cache configuration for input preparation.
            has_lora: NOTE(review): accepted but unused in this body —
                confirm whether it is consumed elsewhere or dead.
            use_aux_hidden_state_outputs: If True, model returns
                (hidden_states, aux_hidden_states) and both are buffered.
            progress_bar_desc: Label for the capture progress bar.
        """
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> Callable[[CUDAGraphMode], None]:
            num_tokens = desc.num_tokens
            # PIECEWISE descriptors carry num_reqs=None; derive a dummy value.
            num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
            num_tokens_across_dp = (
                torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
                if self.dp_size > 1
                else None
            )
            # Prepare dummy attention metadata / slot mappings OUTSIDE the
            # graph, as required by the base-class capture contract.
            attn_metadata, slot_mappings = prepare_inputs_to_capture(
                num_reqs,
                num_tokens,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
            )

            def forward_fn(cg_mode: CUDAGraphMode) -> None:
                batch_descriptor = (
                    BatchDescriptor(num_tokens=num_tokens)
                    if cg_mode == CUDAGraphMode.PIECEWISE
                    else None
                )
                with set_forward_context(
                    attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    cudagraph_runtime_mode=cg_mode,
                    num_tokens_across_dp=num_tokens_across_dp,
                    slot_mapping=slot_mappings,
                    batch_descriptor=batch_descriptor,
                ):
                    model_inputs = {
                        "input_ids": input_buffers.input_ids[:num_tokens],
                        "positions": input_buffers.positions[:num_tokens],
                        **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
                    }
                    model_output = model(**model_inputs)
                    if self.use_aux_hidden_state_outputs:
                        hidden_states, aux_hidden_states = model_output
                    else:
                        hidden_states = model_output
                        aux_hidden_states = []
                    # Lazily allocate static output buffers on the first
                    # (largest) capture forward.
                    if self.hidden_states is None:
                        self.hidden_states = torch.empty_like(hidden_states)
                    if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
                        self.aux_hidden_states = [
                            torch.empty_like(x) for x in aux_hidden_states
                        ]
                    # Copy outputs into the static buffers; this copy is part
                    # of the captured graph, so replay refreshes them.
                    self.hidden_states[:num_tokens] = hidden_states
                    for i, aux in enumerate(aux_hidden_states):
                        self.aux_hidden_states[i][:num_tokens] = aux

            return forward_fn

        super().capture(create_forward_fn, progress_bar_desc)

    def run_fullgraph(
        self, desc: BatchExecutionDescriptor
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay a captured FULL cudagraph and return hidden states.

        Returns views into the static output buffers, trimmed to the
        descriptor's (padded) token count.
        """
        super().run_fullgraph(desc)
        assert self.hidden_states is not None
        hidden_states = self.hidden_states[: desc.num_tokens]
        if not self.use_aux_hidden_state_outputs:
            return hidden_states
        return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]

capture

capture(
    model: Module,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
    has_lora: bool = False,
    use_aux_hidden_state_outputs: bool = False,
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None

Capture CUDA graphs for model forward pass.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def capture(
    self,
    model: nn.Module,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
    has_lora: bool = False,
    use_aux_hidden_state_outputs: bool = False,
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
    """Capture CUDA graphs for model forward pass."""
    # NOTE(review): has_lora is accepted but unused in this body — confirm
    # whether it is consumed elsewhere or dead.
    self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs

    def create_forward_fn(
        desc: BatchExecutionDescriptor,
    ) -> Callable[[CUDAGraphMode], None]:
        num_tokens = desc.num_tokens
        # PIECEWISE descriptors carry num_reqs=None; derive a dummy value.
        num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
        num_tokens_across_dp = (
            torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
            if self.dp_size > 1
            else None
        )
        # Prepare dummy attention metadata / slot mappings OUTSIDE the graph.
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            kv_cache_config,
        )

        def forward_fn(cg_mode: CUDAGraphMode) -> None:
            batch_descriptor = (
                BatchDescriptor(num_tokens=num_tokens)
                if cg_mode == CUDAGraphMode.PIECEWISE
                else None
            )
            with set_forward_context(
                attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
                self.vllm_config,
                num_tokens=num_tokens,
                cudagraph_runtime_mode=cg_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=slot_mappings,
                batch_descriptor=batch_descriptor,
            ):
                model_inputs = {
                    "input_ids": input_buffers.input_ids[:num_tokens],
                    "positions": input_buffers.positions[:num_tokens],
                    **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
                }
                model_output = model(**model_inputs)
                if self.use_aux_hidden_state_outputs:
                    hidden_states, aux_hidden_states = model_output
                else:
                    hidden_states = model_output
                    aux_hidden_states = []
                # Lazily allocate static output buffers on the first forward.
                if self.hidden_states is None:
                    self.hidden_states = torch.empty_like(hidden_states)
                if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
                    self.aux_hidden_states = [
                        torch.empty_like(x) for x in aux_hidden_states
                    ]
                # Copy into the static buffers; the copy is part of the
                # captured graph, so replay refreshes them.
                self.hidden_states[:num_tokens] = hidden_states
                for i, aux in enumerate(aux_hidden_states):
                    self.aux_hidden_states[i][:num_tokens] = aux

        return forward_fn

    super().capture(create_forward_fn, progress_bar_desc)

run_fullgraph

run_fullgraph(
    desc: BatchExecutionDescriptor,
) -> Tensor | tuple[Tensor, list[Tensor]]

Replay a captured FULL cudagraph and return hidden states.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def run_fullgraph(
    self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
    """Replay a captured FULL cudagraph and return hidden states.

    Returns views into the static output buffers, trimmed to the
    descriptor's (padded) token count.
    """
    super().run_fullgraph(desc)
    assert self.hidden_states is not None
    hidden_states = self.hidden_states[: desc.num_tokens]
    if not self.use_aux_hidden_state_outputs:
        return hidden_states
    return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]

get_uniform_token_count

get_uniform_token_count(
    num_reqs: int, num_tokens: int, max_query_len: int
) -> int | None

Return the uniform token count if the batch is uniform, else None. A batch is uniform if all requests have the same number of tokens.

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def get_uniform_token_count(
    num_reqs: int,
    num_tokens: int,
    max_query_len: int,
) -> int | None:
    """
    Return the uniform token count if batch is uniform, else None.
    A batch is uniform if all requests have the same number of tokens.
    """
    if (max_query_len == num_tokens // num_reqs) and (
        num_tokens == max_query_len * num_reqs
    ):
        return max_query_len
    return None