vllm.v1.attention.backends.flashinfer

Attention layer with FlashInfer.

FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT module-attribute

FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = (
    2048 * 1024 * 1024
)
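
For reference, this is a 2 GiB scratch buffer. A minimal, illustrative sketch of how a workspace of this size could be allocated as a flat uint8 tensor on the GPU (the real buffer is managed internally by the backend, not by user code):

import torch

# Hypothetical allocation of a FlashInfer workspace buffer of this size.
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024  # 2 GiB
workspace_buffer = torch.zeros(
    FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT, dtype=torch.uint8, device="cuda"
)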

FP4_DTYPE module-attribute

FP4_DTYPE = uint8

FP8_DTYPE module-attribute

FP8_DTYPE = fp8_dtype()

logger module-attribute

logger = init_logger(__name__)

trtllm_gen_workspace_buffer module-attribute

trtllm_gen_workspace_buffer = None

BatchDCPPrefillWrapper

Source code in vllm/v1/attention/backends/flashinfer.py
class BatchDCPPrefillWrapper:
    def __init__(
        self,
        workspace_buffer: torch.Tensor | None = None,
    ):
        self._context = BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
            workspace_buffer, get_kv_cache_layout()
        )

    def plan(
        self,
        qo_indptr_cpu: torch.Tensor,
        paged_kv_indptr_cpu: torch.Tensor,
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len_cpu: torch.Tensor,
        prefill_start: int,
        page_size: int,
        num_qo_heads: int,
        dcp_world_size: int,
        num_kv_heads: int,
        head_dim: int,
        sm_scale: float,
        window_left: int,
        logits_soft_cap: float | None,
        q_data_type: torch.dtype,
        kv_cache_dtype: torch.dtype,
        prefill_fixed_split_size: int,
        disable_split_kv: bool,
    ):
        """Plan the prefill operation with given parameters."""
        self._context.plan(
            qo_indptr_cpu,
            paged_kv_indptr_cpu,
            paged_kv_indices,
            paged_kv_last_page_len_cpu[prefill_start:],
            num_qo_heads * dcp_world_size,
            num_kv_heads,
            head_dim,
            page_size,
            causal=False,  # This is context run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
            kv_data_type=kv_cache_dtype,
            fixed_split_size=prefill_fixed_split_size,
            disable_split_kv=disable_split_kv,
        )
        self._new_tokens.plan(
            qo_indptr=qo_indptr_cpu,
            kv_indptr=qo_indptr_cpu,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim,
            head_dim_vo=head_dim,
            causal=True,  # This is newtokens run
            sm_scale=sm_scale,
            window_left=window_left,
            logits_soft_cap=logits_soft_cap,
            q_data_type=q_data_type,
        )

    def run(
        self,
        layer: torch.nn.Module,
        prefill_query: torch.Tensor,
        kv_cache_permute: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
    ):
        prefill_query_across_dcp = get_dcp_group().all_gather(
            prefill_query.contiguous(), dim=1
        )
        output_context_tmp, lse_context_tmp = self._context.run(
            prefill_query_across_dcp,
            kv_cache_permute,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
            return_lse=True,
        )
        output_context, lse_context = cp_lse_ag_out_rs(
            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
        )
        lse_context = lse_context.transpose(0, 1).contiguous()

        output_query, lse_query = self._new_tokens.run(
            prefill_query,
            key,
            value,
            return_lse=True,
        )
        lse_query = lse_query.transpose(0, 1).contiguous()

        merge_attn_states(
            out,
            output_context,
            lse_context,
            output_query,
            lse_query,
        )
        return out

_context instance-attribute

_context = BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, get_kv_cache_layout()
)

_new_tokens instance-attribute

_new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, get_kv_cache_layout()
)

__init__

__init__(workspace_buffer: Tensor | None = None)
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(
    self,
    workspace_buffer: torch.Tensor | None = None,
):
    self._context = BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, get_kv_cache_layout()
    )
    self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, get_kv_cache_layout()
    )

plan

plan(
    qo_indptr_cpu: Tensor,
    paged_kv_indptr_cpu: Tensor,
    paged_kv_indices: Tensor,
    paged_kv_last_page_len_cpu: Tensor,
    prefill_start: int,
    page_size: int,
    num_qo_heads: int,
    dcp_world_size: int,
    num_kv_heads: int,
    head_dim: int,
    sm_scale: float,
    window_left: int,
    logits_soft_cap: float | None,
    q_data_type: dtype,
    kv_cache_dtype: dtype,
    prefill_fixed_split_size: int,
    disable_split_kv: bool,
)

Plan the prefill operation with given parameters.

Source code in vllm/v1/attention/backends/flashinfer.py
def plan(
    self,
    qo_indptr_cpu: torch.Tensor,
    paged_kv_indptr_cpu: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len_cpu: torch.Tensor,
    prefill_start: int,
    page_size: int,
    num_qo_heads: int,
    dcp_world_size: int,
    num_kv_heads: int,
    head_dim: int,
    sm_scale: float,
    window_left: int,
    logits_soft_cap: float | None,
    q_data_type: torch.dtype,
    kv_cache_dtype: torch.dtype,
    prefill_fixed_split_size: int,
    disable_split_kv: bool,
):
    """Plan the prefill operation with given parameters."""
    self._context.plan(
        qo_indptr_cpu,
        paged_kv_indptr_cpu,
        paged_kv_indices,
        paged_kv_last_page_len_cpu[prefill_start:],
        num_qo_heads * dcp_world_size,
        num_kv_heads,
        head_dim,
        page_size,
        causal=False,  # This is context run
        sm_scale=sm_scale,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=q_data_type,
        kv_data_type=kv_cache_dtype,
        fixed_split_size=prefill_fixed_split_size,
        disable_split_kv=disable_split_kv,
    )
    self._new_tokens.plan(
        qo_indptr=qo_indptr_cpu,
        kv_indptr=qo_indptr_cpu,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim_qk=head_dim,
        head_dim_vo=head_dim,
        causal=True,  # This is newtokens run
        sm_scale=sm_scale,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=q_data_type,
    )
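
The qo_indptr_cpu tensor passed here follows the usual FlashInfer convention: an exclusive prefix sum of per-request query lengths. Note that it is reused as kv_indptr for the ragged new-tokens run, since the keys and values of that run are the new tokens themselves. A small illustrative sketch with made-up lengths:

import torch

query_lens = torch.tensor([3, 1, 5], dtype=torch.int32)      # three prefill requests
qo_indptr_cpu = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr_cpu[1:] = torch.cumsum(query_lens, dim=0)
print(qo_indptr_cpu)  # tensor([0, 3, 4, 9], dtype=torch.int32)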

run

run(
    layer: Module,
    prefill_query: Tensor,
    kv_cache_permute: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
)
Source code in vllm/v1/attention/backends/flashinfer.py
def run(
    self,
    layer: torch.nn.Module,
    prefill_query: torch.Tensor,
    kv_cache_permute: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
):
    prefill_query_across_dcp = get_dcp_group().all_gather(
        prefill_query.contiguous(), dim=1
    )
    output_context_tmp, lse_context_tmp = self._context.run(
        prefill_query_across_dcp,
        kv_cache_permute,
        k_scale=layer._k_scale_float,
        v_scale=layer._v_scale_float,
        return_lse=True,
    )
    output_context, lse_context = cp_lse_ag_out_rs(
        output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
    )
    lse_context = lse_context.transpose(0, 1).contiguous()

    output_query, lse_query = self._new_tokens.run(
        prefill_query,
        key,
        value,
        return_lse=True,
    )
    lse_query = lse_query.transpose(0, 1).contiguous()

    merge_attn_states(
        out,
        output_context,
        lse_context,
        output_query,
        lse_query,
    )
    return out
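
run combines the context-attention partials (paged KV cache) with the new-token partials (ragged KV) using their log-sum-exp values. The following is a CPU reference sketch of that merge math, assuming per-(token, head) natural-log softmax normalizers; it illustrates the idea, not the exact tensor layout of the merge_attn_states kernel:

import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [num_tokens, num_heads, head_size], lse*: [num_tokens, num_heads]
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - max_lse).unsqueeze(-1)
    # Weighted average of the two partial outputs by their softmax mass.
    return (o1 * w1 + o2 * w2) / (w1 + w2)

o1, o2 = torch.randn(4, 8, 64), torch.randn(4, 8, 64)
lse1, lse2 = torch.randn(4, 8), torch.randn(4, 8)
merged = merge_partial_attention(o1, lse1, o2, lse2)  # [4, 8, 64]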

FlashInferBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # Note: not verified for all platforms, but on Blackwell only page
    # sizes of 16, 32, and 64 are supported.
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"

    @staticmethod
    def get_impl_cls() -> type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        return [64, 128, 256]

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
            12, 1
        )

    @classmethod
    def supports_sink(cls) -> bool:
        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
        from vllm.utils.flashinfer import (
            force_use_trtllm_attention,
            supports_trtllm_attention,
        )

        # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
        if force_use_trtllm_attention() is False:
            return False

        # Check if TRTLLM is supported on this platform
        return supports_trtllm_attention()

    @classmethod
    def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
        from vllm.platforms import current_platform

        capability = current_platform.get_device_capability()
        if capability is not None and capability.major == 10:
            return "HND"
        return None

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

supported_dtypes class-attribute

supported_dtypes: list[dtype] = [float16, bfloat16]

supported_kernel_block_sizes class-attribute

supported_kernel_block_sizes: list[int | MultipleOf] = [
    16,
    32,
    64,
]

supported_kv_cache_dtypes class-attribute

supported_kv_cache_dtypes: list[CacheDType] = [
    "auto",
    "fp8",
    "fp8_e4m3",
    "fp8_e5m2",
]

get_builder_cls staticmethod

get_builder_cls() -> type[FlashInferMetadataBuilder]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_builder_cls() -> type["FlashInferMetadataBuilder"]:
    return FlashInferMetadataBuilder

get_fp8_dtype_for_flashinfer staticmethod

get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> dtype
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

get_impl_cls staticmethod

get_impl_cls() -> type[FlashInferImpl]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_impl_cls() -> type["FlashInferImpl"]:
    return FlashInferImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    return (num_blocks, 2, block_size, num_kv_heads, head_size)
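
For a rough sense of scale, the element count follows directly from this shape. A small illustration with made-up sizes:

num_blocks, block_size, num_kv_heads, head_size = 4096, 16, 8, 128
shape = (num_blocks, 2, block_size, num_kv_heads, head_size)

num_elements = 1
for dim in shape:
    num_elements *= dim

# (4096, 2, 16, 8, 128) -> 134,217,728 elements; at 2 bytes per element (bf16)
# that is 256 MiB of KV cache for this hypothetical configuration.
print(shape, num_elements)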

get_kv_cache_stride_order staticmethod

get_kv_cache_stride_order() -> tuple[int, ...]
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
    # `stride_order` indicates the permutation that gets us from
    # `get_kv_cache_shape` to the actual memory layout we want.
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD":
        stride_order = (0, 1, 2, 3, 4)
    elif cache_layout == "HND":
        stride_order = (0, 1, 3, 2, 4)
    else:
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
    return stride_order
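
To see what the permutation does, apply it to the logical shape returned by get_kv_cache_shape; with the HND order the head dimension moves ahead of the block (token) dimension, while the NHD order is the identity. Pure shape arithmetic for illustration:

logical_shape = (1024, 2, 16, 8, 128)  # (num_blocks, 2, block_size, num_kv_heads, head_size)
nhd_order = (0, 1, 2, 3, 4)
hnd_order = (0, 1, 3, 2, 4)

print(tuple(logical_shape[i] for i in nhd_order))  # (1024, 2, 16, 8, 128)
print(tuple(logical_shape[i] for i in hnd_order))  # (1024, 2, 8, 16, 128)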

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/flashinfer.py
@staticmethod
def get_name() -> str:
    return "FLASHINFER"

get_required_kv_cache_layout classmethod

get_required_kv_cache_layout() -> KVCacheLayoutType | None
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
    from vllm.platforms import current_platform

    capability = current_platform.get_device_capability()
    if capability is not None and capability.major == 10:
        return "HND"
    return None

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
    return [64, 128, 256]

supports_compute_capability classmethod

supports_compute_capability(
    capability: DeviceCapability,
) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
    return capability >= DeviceCapability(7, 5) and capability <= DeviceCapability(
        12, 1
    )
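
The check accepts any compute capability from 7.5 (Turing) up to and including 12.1. A sketch of the same window check using plain (major, minor) tuples, which compare lexicographically in the same way DeviceCapability values do:

def in_supported_window(capability: tuple[int, int]) -> bool:
    # Equivalent comparison: (7, 5) <= (major, minor) <= (12, 1)
    return (7, 5) <= capability <= (12, 1)

print(in_supported_window((7, 0)))   # False (Volta)
print(in_supported_window((9, 0)))   # True  (Hopper)
print(in_supported_window((12, 1)))  # True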

supports_sink classmethod

supports_sink() -> bool

FlashInfer supports sinks when TRTLLM attention is available (SM100).

Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
def supports_sink(cls) -> bool:
    """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
    from vllm.utils.flashinfer import (
        force_use_trtllm_attention,
        supports_trtllm_attention,
    )

    # Respect explicit disable flag (e.g., VLLM_USE_TRTLLM_ATTENTION=0)
    if force_use_trtllm_attention() is False:
        return False

    # Check if TRTLLM is supported on this platform
    return supports_trtllm_attention()

FlashInferImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferImpl(AttentionImpl):
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: int | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.window_left = (
            self.sliding_window[0] if self.sliding_window is not None else -1
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferImpl"
            )

        self.sinks: torch.Tensor | None = None
        if sinks is not None:
            if sinks.shape[0] != num_heads:
                raise ValueError(
                    "Sinks must have the same number of heads as the number of "
                    f"heads in the layer. Expected {num_heads}, but got "
                    f"{sinks.shape[0]}."
                )
            self.sinks = sinks

        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None
        self.o_sf_scale: float | None = None

    def fused_output_quant_supported(self, quant_key: QuantKey):
        return (
            self.support_trtllm_attn
            and self.kv_cache_dtype.startswith("fp8")
            and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
        )

    def supports_quant_query_input(self) -> bool:
        if flashinfer_disable_q_quantization():
            return False

        return self.support_trtllm_attn

    # FlashInfer requires attention sinks to be float32
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if self.sinks is not None and self.sinks.dtype != torch.float32:
            self.sinks = self.sinks.to(torch.float32)

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashInferMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashInfer.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: KV cache tensor with different possible shapes:
                - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
                - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Ensure query dtype matches the expected dtype from attention metadata
        assert attn_metadata.q_data_type == query.dtype, (
            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
            f"got {query.dtype}"
        )

        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        # The attn+quant fusion happens when output_scale is provided.
        if output_scale is None:
            assert output_block_scale is None, (
                "output_block_scale is not supported when fusion has not happened"
            )
        else:
            assert attn_metadata.q_data_type == FP8_DTYPE, (
                "Query must be FP8 when attn+quant fusion happened."
            )
            assert (
                attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
            ), "Must use TRT-LLM attn"

            if output.dtype == FP8_DTYPE:
                assert output_block_scale is None, (
                    "output_block_scale should not be provided for fp8 output"
                )
            elif output.dtype == FP4_DTYPE:
                assert output_block_scale is not None, (
                    "output_block_scale is required for nvfp4 output"
                )
            else:
                raise ValueError(f"Unsupported output dtype: {output.dtype}")

            # The TRTLLM attn kernel requires the output scale to be passed as
            # a host scalar, so store it as a host scalar during the warmup run
            # (which executes before CUDA graphs are enabled).
            if layer._o_scale_float is None:
                layer._o_scale_float = output_scale.cpu().item()
                if output.dtype == FP8_DTYPE:
                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
                elif output.dtype == FP4_DTYPE:
                    self.o_sf_scale = layer._o_scale_float

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
            # to process the cache when the kv_cache_dtype is fp8
            if self.kv_cache_dtype.startswith("fp8"):
                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.kv_cache_dtype
                )
                kv_cache = kv_cache.view(torch_dtype)

        # Inputs and outputs may be padded for CUDA graphs
        query = query[:num_actual_tokens]
        key = key[:num_actual_tokens]
        value = value[:num_actual_tokens]
        output_padded = output
        output = output[:num_actual_tokens]

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

        # When using spec decoding, num_decodes can be < num_decode_tokens
        # because some decode requests may have more than one query token.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back.
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
            assert prefill_query.shape[0] == num_prefill_tokens
            assert prefill_wrapper is not None

            if not attn_metadata.prefill_use_trtllm:
                if self.dcp_world_size > 1:
                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                    assert prefill_wrapper._context._window_left == self.window_left
                    assert prefill_wrapper._context._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._context._sm_scale == self.scale
                    assert not prefill_wrapper._context._causal
                    assert prefill_wrapper._new_tokens._window_left == self.window_left
                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
                    assert prefill_wrapper._new_tokens._causal

                    prefill_wrapper.run(
                        layer,
                        prefill_query,
                        kv_cache_permute,
                        key[num_decode_tokens:],
                        value[num_decode_tokens:],
                        out=output[num_decode_tokens:],
                    )
                else:
                    assert isinstance(
                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
                    )
                    assert prefill_wrapper._window_left == self.window_left
                    assert prefill_wrapper._logits_soft_cap == (
                        self.logits_soft_cap or 0.0
                    )
                    assert prefill_wrapper._sm_scale == self.scale
                    assert prefill_wrapper._causal
                    prefill_wrapper.run(
                        prefill_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[num_decode_tokens:],
                    )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
                seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert prefill_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_prefill.is_contiguous()
                assert seq_lens_prefill.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[num_decode_tokens:],
                        scale=output_block_scale,
                        scale_start_index=num_decode_tokens,
                        original_shape=prefill_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[num_decode_tokens:]

                if (
                    attn_metadata.q_data_type != FP8_DTYPE
                    and self.kv_cache_dtype.startswith("fp8")
                ):
                    # TRTLLM prefill attention does not support BF16 Q with an
                    # fp8 KV cache. To enable prefill attention with an fp8 KV
                    # cache, construct a mock block table and a mock
                    # (dequantized) BF16 KV cache for the prefill.
                    mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                        kv_cache_permute,
                        block_tables_prefill,
                        layer._k_scale,
                        layer._v_scale,
                        attn_metadata.q_data_type,
                    )
                else:
                    mock_kv_cache = kv_cache_permute
                    mock_block_table = block_tables_prefill

                trtllm_batch_context_with_kv_cache(
                    query=prefill_query,
                    kv_cache=mock_kv_cache,
                    workspace_buffer=workspace_buffer,
                    block_tables=mock_block_table,
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len_prefill,
                    max_kv_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                )

        if num_decode_tokens > 0:
            decode_wrapper = attn_metadata.decode_wrapper
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
                assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                assert decode_wrapper._sm_scale == self.scale

                if self.dcp_world_size > 1:
                    decode_query = get_dcp_group().all_gather(
                        decode_query.contiguous(), dim=-2
                    )
                    output_tmp = torch.empty_like(decode_query)
                    lse = torch.empty(
                        (decode_query.size(0), decode_query.size(1)),
                        dtype=torch.float32,
                        device=decode_query.device,
                    )
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output_tmp,
                        lse=lse,
                        return_lse=True,
                    )
                    output[:num_decode_tokens] = cp_lse_ag_out_rs(
                        output_tmp, lse, get_dcp_group()
                    )
                else:
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[:num_decode_tokens],
                    )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = _get_trtllm_gen_workspace_buffer()
                block_tables_decode = attn_metadata.block_table_tensor[
                    :num_decode_tokens
                ]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                assert get_kv_cache_layout() == "HND"
                assert decode_query.is_contiguous()
                assert kv_cache_permute.is_contiguous()
                assert workspace_buffer.is_contiguous()
                assert block_tables_decode.is_contiguous()
                assert seq_lens_decode.is_contiguous()

                if output.dtype == FP4_DTYPE:
                    assert self.o_sf_scale is not None
                    out = FP4Tensor(
                        data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape,
                    )
                else:
                    assert self.o_sf_scale is None
                    out = output[:num_decode_tokens]

                if num_decode_tokens % attn_metadata.num_decodes != 0:
                    # This gets triggered when the dummy_run forces
                    # attention to be initialized with q_len = 0
                    q_len_per_req = 1
                else:
                    q_len_per_req = num_decode_tokens // attn_metadata.num_decodes

                trtllm_batch_decode_with_kv_cache(
                    query=decode_query,
                    kv_cache=kv_cache_permute,
                    workspace_buffer=workspace_buffer,
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
                    bmm1_scale=self.bmm1_scale,
                    bmm2_scale=self.bmm2_scale,
                    window_left=self.window_left,
                    sinks=self.sinks,
                    o_sf_scale=self.o_sf_scale,
                    out=out,
                    q_len_per_req=q_len_per_req,
                )
        return output_padded

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

bmm1_scale instance-attribute

bmm1_scale: float | None = None

bmm2_scale instance-attribute

bmm2_scale: float | None = None

can_return_lse_for_decode class-attribute instance-attribute

can_return_lse_for_decode: bool = True

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

o_sf_scale instance-attribute

o_sf_scale: float | None = None

scale instance-attribute

scale = float(scale)

sinks instance-attribute

sinks: Tensor | None = None

sliding_window instance-attribute

sliding_window = (-1, -1)

support_trtllm_attn instance-attribute

support_trtllm_attn = can_use_trtllm_attention(
    num_heads, num_kv_heads
)

window_left instance-attribute

window_left = (
    sliding_window[0] if sliding_window is not None else -1
)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: int | None = None,
    sinks: Tensor | None = None,
) -> None
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: int | None = None,
    sinks: torch.Tensor | None = None,
) -> None:
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    if sliding_window is None:
        self.sliding_window = (-1, -1)
    else:
        self.sliding_window = (sliding_window - 1, 0)
    self.window_left = (
        self.sliding_window[0] if self.sliding_window is not None else -1
    )
    self.kv_cache_dtype = kv_cache_dtype
    self.logits_soft_cap = logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError(
            "Encoder self-attention and "
            "encoder/decoder cross-attention "
            "are not implemented for "
            "FlashInferImpl"
        )

    self.sinks: torch.Tensor | None = None
    if sinks is not None:
        if sinks.shape[0] != num_heads:
            raise ValueError(
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Expected {num_heads}, but got "
                f"{sinks.shape[0]}."
            )
        self.sinks = sinks

    self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
    self.bmm1_scale: float | None = None
    self.bmm2_scale: float | None = None
    self.o_sf_scale: float | None = None
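
Note how the sliding window is normalized: a window of size w becomes (w - 1, 0), so window_left is either w - 1 or -1 when no window is configured. A tiny illustration of that mapping (not vLLM API, just the same arithmetic as above):

def to_window_left(sliding_window: int | None) -> int:
    # Mirrors the __init__ logic: -1 disables the window.
    return -1 if sliding_window is None else sliding_window - 1

print(to_window_left(None))   # -1
print(to_window_left(4096))   # 4095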

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: FlashInferMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor

Forward pass with FlashInfer.

Parameters:

    query (Tensor): shape = [num_tokens, num_heads, head_size]. Required.
    key (Tensor): shape = [num_tokens, num_kv_heads, head_size]. Required.
    value (Tensor): shape = [num_tokens, num_kv_heads, head_size]. Required.
    kv_cache (Tensor): KV cache tensor with different possible shapes:
        - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
        - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
        Required.
    attn_metadata (FlashInferMetadata): Metadata for attention. Required.

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/v1/attention/backends/flashinfer.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: FlashInferMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass with FlashInfer.

    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: KV cache tensor with different possible shapes:
            - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
            - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    assert output is not None, "Output tensor must be provided."

    if attn_metadata is None:
        # Profiling run.
        return output.fill_(0)

    # Ensure query dtype matches the expected dtype from attention metadata
    assert attn_metadata.q_data_type == query.dtype, (
        f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
        f"got {query.dtype}"
    )

    if self.bmm1_scale is None:
        self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

    if self.bmm2_scale is None:
        self.bmm2_scale = layer._v_scale_float

    # The attn+quant fusion happens when output_scale is provided.
    if output_scale is None:
        assert output_block_scale is None, (
            "output_block_scale is not supported when fusion has not happened"
        )
    else:
        assert attn_metadata.q_data_type == FP8_DTYPE, (
            "Query must be FP8 when attn+quant fusion happened."
        )
        assert (
            attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm
        ), "Must use TRT-LLM attn"

        if output.dtype == FP8_DTYPE:
            assert output_block_scale is None, (
                "output_block_scale should not be provided for fp8 output"
            )
        elif output.dtype == FP4_DTYPE:
            assert output_block_scale is not None, (
                "output_block_scale is required for nvfp4 output"
            )
        else:
            raise ValueError(f"Unsupported output dtype: {output.dtype}")

        # The TRTLLM attn kernel requires the output scale to be passed as
        # a host scalar, so store it as a host scalar during the warmup run
        # (which executes before CUDA graphs are enabled).
        if layer._o_scale_float is None:
            layer._o_scale_float = output_scale.cpu().item()
            if output.dtype == FP8_DTYPE:
                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
            elif output.dtype == FP4_DTYPE:
                self.o_sf_scale = layer._o_scale_float

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
    # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
    # in this method. For example, `view` and `slice` (or `[:n]`) operations
    # are surprisingly slow even in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.

    num_actual_tokens = attn_metadata.num_actual_tokens

    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            kv_cache[:, 0],
            kv_cache[:, 1],
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
        if self.kv_cache_dtype.startswith("fp8"):
            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.kv_cache_dtype
            )
            kv_cache = kv_cache.view(torch_dtype)

    # Inputs and outputs may be padded for CUDA graphs
    query = query[:num_actual_tokens]
    key = key[:num_actual_tokens]
    value = value[:num_actual_tokens]
    output_padded = output
    output = output[:num_actual_tokens]

    if attn_metadata.use_cascade:
        # Cascade attention (rare case).
        assert attn_metadata.cascade_wrapper is not None
        output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
        return output

    # When using spec decoding, num_decodes can be < num_decode_tokens
    # because some decode requests may have more than one query token.
    num_decodes = attn_metadata.num_decodes
    num_decode_tokens = attn_metadata.num_decode_tokens
    num_prefill_tokens = attn_metadata.num_prefill_tokens

    stride_order = FlashInferBackend.get_kv_cache_stride_order()
    kv_cache_permute = kv_cache.permute(*stride_order)
    # Regular attention (common case).
    # Decodes are at the front and prefills are at the back.
    if num_prefill_tokens > 0:
        prefill_wrapper = attn_metadata.prefill_wrapper
        prefill_query = query[num_decode_tokens:]
        assert prefill_query.shape[0] == num_prefill_tokens
        assert prefill_wrapper is not None

        if not attn_metadata.prefill_use_trtllm:
            if self.dcp_world_size > 1:
                assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
                assert prefill_wrapper._context._window_left == self.window_left
                assert prefill_wrapper._context._logits_soft_cap == (
                    self.logits_soft_cap or 0.0
                )
                assert prefill_wrapper._context._sm_scale == self.scale
                assert not prefill_wrapper._context._causal
                assert prefill_wrapper._new_tokens._window_left == self.window_left
                assert prefill_wrapper._new_tokens._logits_soft_cap == (
                    self.logits_soft_cap or 0.0
                )
                assert prefill_wrapper._new_tokens._sm_scale == self.scale
                assert prefill_wrapper._new_tokens._causal

                prefill_wrapper.run(
                    layer,
                    prefill_query,
                    kv_cache_permute,
                    key[num_decode_tokens:],
                    value[num_decode_tokens:],
                    out=output[num_decode_tokens:],
                )
            else:
                assert isinstance(
                    prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
                )
                assert prefill_wrapper._window_left == self.window_left
                assert prefill_wrapper._logits_soft_cap == (
                    self.logits_soft_cap or 0.0
                )
                assert prefill_wrapper._sm_scale == self.scale
                assert prefill_wrapper._causal
                prefill_wrapper.run(
                    prefill_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[num_decode_tokens:],
                )
        else:
            # prefill_query may be non-contiguous
            prefill_query = prefill_query.contiguous()
            workspace_buffer = _get_trtllm_gen_workspace_buffer()
            block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:]
            seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]

            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            assert get_kv_cache_layout() == "HND"
            assert prefill_query.is_contiguous()
            assert kv_cache_permute.is_contiguous()
            assert workspace_buffer.is_contiguous()
            assert block_tables_prefill.is_contiguous()
            assert seq_lens_prefill.is_contiguous()

            if output.dtype == FP4_DTYPE:
                assert self.o_sf_scale is not None
                out = FP4Tensor(
                    data=output[num_decode_tokens:],
                    scale=output_block_scale,
                    scale_start_index=num_decode_tokens,
                    original_shape=prefill_query.shape,
                )
            else:
                assert self.o_sf_scale is None
                out = output[num_decode_tokens:]

            if (
                attn_metadata.q_data_type != FP8_DTYPE
                and self.kv_cache_dtype.startswith("fp8")
            ):
                # TRTLLM prefill attention does not support BF16 Q with an
                # fp8 KV cache. To enable prefill attention with an fp8 KV
                # cache, construct a mock block table and a mock
                # (dequantized) BF16 KV cache for the prefill.
                mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
                    kv_cache_permute,
                    block_tables_prefill,
                    layer._k_scale,
                    layer._v_scale,
                    attn_metadata.q_data_type,
                )
            else:
                mock_kv_cache = kv_cache_permute
                mock_block_table = block_tables_prefill

            trtllm_batch_context_with_kv_cache(
                query=prefill_query,
                kv_cache=mock_kv_cache,
                workspace_buffer=workspace_buffer,
                block_tables=mock_block_table,
                seq_lens=seq_lens_prefill,
                max_q_len=attn_metadata.max_q_len_prefill,
                max_kv_len=attn_metadata.max_seq_len,
                bmm1_scale=self.bmm1_scale,
                bmm2_scale=self.bmm2_scale,
                batch_size=attn_metadata.num_prefills,
                cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                window_left=self.window_left,
                sinks=self.sinks,
                o_sf_scale=self.o_sf_scale,
                out=out,
            )

    if num_decode_tokens > 0:
        decode_wrapper = attn_metadata.decode_wrapper
        decode_query = query[:num_decode_tokens]
        assert decode_query.shape[0] == num_decode_tokens
        assert decode_wrapper is not None

        if not attn_metadata.decode_use_trtllm:
            assert decode_wrapper._window_left == self.window_left
            assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
            assert decode_wrapper._sm_scale == self.scale

            if self.dcp_world_size > 1:
                decode_query = get_dcp_group().all_gather(
                    decode_query.contiguous(), dim=-2
                )
                output_tmp = torch.empty_like(decode_query)
                lse = torch.empty(
                    (decode_query.size(0), decode_query.size(1)),
                    dtype=torch.float32,
                    device=decode_query.device,
                )
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output_tmp,
                    lse=lse,
                    return_lse=True,
                )
                output[:num_decode_tokens] = cp_lse_ag_out_rs(
                    output_tmp, lse, get_dcp_group()
                )
            else:
                decode_wrapper.run(
                    decode_query,
                    kv_cache_permute,
                    k_scale=layer._k_scale_float,
                    v_scale=layer._v_scale_float,
                    out=output[:num_decode_tokens],
                )
        else:
            # decode_query may be non-contiguous
            decode_query = decode_query.contiguous()
            workspace_buffer = _get_trtllm_gen_workspace_buffer()
            block_tables_decode = attn_metadata.block_table_tensor[
                :num_decode_tokens
            ]
            seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

            # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
            assert get_kv_cache_layout() == "HND"
            assert decode_query.is_contiguous()
            assert kv_cache_permute.is_contiguous()
            assert workspace_buffer.is_contiguous()
            assert block_tables_decode.is_contiguous()
            assert seq_lens_decode.is_contiguous()

            if output.dtype == FP4_DTYPE:
                assert self.o_sf_scale is not None
                out = FP4Tensor(
                    data=output[:num_decode_tokens],
                    scale=output_block_scale,
                    scale_start_index=0,
                    original_shape=decode_query.shape,
                )
            else:
                assert self.o_sf_scale is None
                out = output[:num_decode_tokens]

            if num_decode_tokens % attn_metadata.num_decodes != 0:
                # This gets triggered when the dummy_run forces
                # attention to be initialized with q_len = 0
                q_len_per_req = 1
            else:
                q_len_per_req = num_decode_tokens // attn_metadata.num_decodes

            trtllm_batch_decode_with_kv_cache(
                query=decode_query,
                kv_cache=kv_cache_permute,
                workspace_buffer=workspace_buffer,
                block_tables=block_tables_decode,
                seq_lens=seq_lens_decode,
                max_seq_len=attn_metadata.max_seq_len,
                bmm1_scale=self.bmm1_scale,
                bmm2_scale=self.bmm2_scale,
                window_left=self.window_left,
                sinks=self.sinks,
                o_sf_scale=self.o_sf_scale,
                out=out,
                q_len_per_req=q_len_per_req,
            )
    return output_padded
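
The decode branch above relies on the batch having been reordered so that decode requests, and their tokens, sit in front of the prefill tokens. Below is a minimal sketch with hypothetical shapes showing how the query tensor is sliced and how q_len_per_req is derived for the TRTLLM decode call.

import torch

# Hypothetical shapes for illustration only; real values come from the metadata.
num_decodes, num_decode_tokens, num_prefill_tokens = 4, 8, 300
num_heads, head_dim = 32, 128
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_dim)

decode_query = query[:num_decode_tokens]   # decode tokens sit at the front
prefill_query = query[num_decode_tokens:]  # prefill tokens follow

# Uniform batches satisfy num_decode_tokens % num_decodes == 0, so the
# per-request query length passed to trtllm_batch_decode_with_kv_cache is:
q_len_per_req = num_decode_tokens // num_decodes  # -> 2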

fused_output_quant_supported

fused_output_quant_supported(quant_key: QuantKey)
Source code in vllm/v1/attention/backends/flashinfer.py
def fused_output_quant_supported(self, quant_key: QuantKey):
    return (
        self.support_trtllm_attn
        and self.kv_cache_dtype.startswith("fp8")
        and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
    )
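
A standalone sketch of the same predicate, using placeholder string keys; the real check compares QuantKey objects such as kFp8StaticTensorSym and kNvfp4Quant.

# Placeholder keys; the backend compares actual QuantKey instances.
_SUPPORTED_QUANT_KEYS = ("fp8-static-tensor-sym", "nvfp4")

def fused_output_quant_supported_sketch(
    support_trtllm_attn: bool, kv_cache_dtype: str, quant_key: str
) -> bool:
    # Fused output quantization needs the TRTLLM path and an fp8 KV cache.
    return (
        support_trtllm_attn
        and kv_cache_dtype.startswith("fp8")
        and quant_key in _SUPPORTED_QUANT_KEYS
    )

assert fused_output_quant_supported_sketch(True, "fp8_e4m3", "nvfp4")
assert not fused_output_quant_supported_sketch(True, "auto", "nvfp4")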

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/v1/attention/backends/flashinfer.py
def process_weights_after_loading(self, act_dtype: torch.dtype):
    if self.sinks is not None and self.sinks.dtype != torch.float32:
        self.sinks = self.sinks.to(torch.float32)

supports_quant_query_input

supports_quant_query_input() -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
def supports_quant_query_input(self) -> bool:
    if flashinfer_disable_q_quantization():
        return False

    return self.support_trtllm_attn

FlashInferMetadata dataclass

Source code in vllm/v1/attention/backends/flashinfer.py
@dataclass
class FlashInferMetadata:
    num_actual_tokens: int  # Number of tokens excluding padding.

    # The data type of the query
    q_data_type: torch.dtype

    slot_mapping: torch.Tensor

    # For flashinfer trtllm batch decode
    max_q_len: int
    max_q_len_prefill: int
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table_tensor: torch.Tensor
    prefill_use_trtllm: bool
    decode_use_trtllm: bool

    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # For cascade attention (CPU for planning).
    use_cascade: bool

    prefill_wrapper: (
        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
    ) = None
    decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None

    qo_indptr_gpu: torch.Tensor | None = None
    paged_kv_indptr_gpu: torch.Tensor | None = None

block_table_tensor instance-attribute

block_table_tensor: Tensor

cascade_wrapper class-attribute instance-attribute

cascade_wrapper: (
    MultiLevelCascadeAttentionWrapper | None
) = None

decode_use_trtllm instance-attribute

decode_use_trtllm: bool

decode_wrapper class-attribute instance-attribute

decode_wrapper: (
    BatchDecodeWithPagedKVCacheWrapper | None
) = None

max_q_len instance-attribute

max_q_len: int

max_q_len_prefill instance-attribute

max_q_len_prefill: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

paged_kv_indptr_gpu class-attribute instance-attribute

paged_kv_indptr_gpu: Tensor | None = None

prefill_use_trtllm instance-attribute

prefill_use_trtllm: bool

prefill_wrapper class-attribute instance-attribute

prefill_wrapper: (
    BatchPrefillWithPagedKVCacheWrapper
    | BatchDCPPrefillWrapper
    | None
) = None

q_data_type instance-attribute

q_data_type: dtype

qo_indptr_gpu class-attribute instance-attribute

qo_indptr_gpu: Tensor | None = None

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

use_cascade instance-attribute

use_cascade: bool

__init__

__init__(
    num_actual_tokens: int,
    q_data_type: dtype,
    slot_mapping: Tensor,
    max_q_len: int,
    max_q_len_prefill: int,
    max_seq_len: int,
    seq_lens: Tensor,
    block_table_tensor: Tensor,
    prefill_use_trtllm: bool,
    decode_use_trtllm: bool,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_prefill_tokens: int,
    use_cascade: bool,
    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper
    | BatchDCPPrefillWrapper
    | None = None,
    decode_wrapper: BatchDecodeWithPagedKVCacheWrapper
    | None = None,
    cascade_wrapper: MultiLevelCascadeAttentionWrapper
    | None = None,
    qo_indptr_gpu: Tensor | None = None,
    paged_kv_indptr_gpu: Tensor | None = None,
) -> None
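
In practice this metadata is produced by FlashInferMetadataBuilder.build(). The snippet below is only a construction sketch with dummy tensors for a single one-token decode request, assuming vLLM with FlashInfer is importable; the field values are illustrative, not representative of a real schedule.

import torch
from vllm.v1.attention.backends.flashinfer import FlashInferMetadata

md = FlashInferMetadata(
    num_actual_tokens=1,
    q_data_type=torch.bfloat16,
    slot_mapping=torch.zeros(1, dtype=torch.int64),
    max_q_len=1,
    max_q_len_prefill=1,
    max_seq_len=9,
    seq_lens=torch.tensor([9], dtype=torch.int32),
    block_table_tensor=torch.zeros((1, 1), dtype=torch.int32),
    prefill_use_trtllm=False,
    decode_use_trtllm=False,
    num_decodes=1,
    num_decode_tokens=1,
    num_prefills=0,
    num_prefill_tokens=0,
    use_cascade=False,
)
# Decode and prefill token counts partition the actual tokens.
assert md.num_decode_tokens + md.num_prefill_tokens == md.num_actual_tokens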

FlashInferMetadataBuilder

Bases: AttentionMetadataBuilder[FlashInferMetadata]

Source code in vllm/v1/attention/backends/flashinfer.py
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        self._workspace_buffer = None
        self._prefill_wrapper: (
            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
        ) = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)

        if vllm_is_batch_invariant():
            self.decode_fixed_split_size = 2048
            self.prefill_fixed_split_size = 4096
            self.disable_split_kv = True
        else:
            self.decode_fixed_split_size = -1
            self.prefill_fixed_split_size = -1
            self.disable_split_kv = False

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(
            self.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        max_num_pages = max_num_reqs * max_num_pages_per_req
        speculative_config = vllm_config.speculative_config
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if speculative_config is not None
            else 0
        )
        self.enable_cuda_graph = (
            self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        )
        if self.enable_cuda_graph:
            # For full cudagraph capture, one `decode_wrapper` for each batch
            # size is needed for FlashInfer.
            self._decode_wrappers_cudagraph: dict[
                int, BatchDecodeWithPagedKVCacheWrapper
            ] = {}
            self._decode_cudagraph_max_bs = min(
                (1 + num_spec_tokens) * max_num_reqs,
                self.compilation_config.max_cudagraph_capture_size,
            )

        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
            self.dcp_kv_cache_interleave_size = (
                vllm_config.parallel_config.dcp_kv_cache_interleave_size
            )
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
            self.dcp_kv_cache_interleave_size = 1

        self.num_qo_heads = (
            self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
            * self.dcp_world_size
        )

        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
        self.head_dim = self.kv_cache_spec.head_size
        self.page_size = self.kv_cache_spec.block_size

        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
                self.cache_dtype
            )
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype

        # Use model dtype as q dtype when TRTLLM attn is not supported, or
        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
        # use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
        if can_use_trtllm and not flashinfer_disable_q_quantization():
            self.q_data_type = self.kv_cache_dtype
        else:
            self.q_data_type = self.model_config.dtype

        # Prefer TRTLLM attention for decoding in all cases.
        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
        self.use_trtllm_decode_attention = can_use_trtllm
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

        self._cascade_wrapper = None  # Wrapper for cascade attention

        # Global hyperparameters shared by all attention layers
        # TODO: discard this for trtllm-gen backend
        self.global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
        )
        self.sm_scale = self.global_hyperparameters.sm_scale
        self.window_left = self.global_hyperparameters.window_left
        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
        self.has_sinks = self.global_hyperparameters.has_sinks
        if self.has_sinks and not can_use_trtllm:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention on "
                "earlier GPUs."
            )
        # Preparing persistent buffers (device-side)
        self.paged_kv_indptr = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=self.device
        )
        self.paged_kv_indices = torch.zeros(
            max_num_pages,  # max num pages possible
            dtype=torch.int32,
            device=self.device,
        )
        self.paged_kv_last_page_len = torch.zeros(
            max_num_reqs, dtype=torch.int32, device=self.device
        )
        # host-side buffer
        pin_memory = is_pin_memory_available()
        self.paged_kv_indptr_cpu = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
        self.paged_kv_indptr_buffer = torch.zeros_like(
            self.paged_kv_indptr_cpu, pin_memory=pin_memory
        )
        self.paged_kv_indices_cpu = torch.zeros(
            max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_last_page_len_cpu = torch.zeros(
            max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
        )
        self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()

        if self.head_dim == 256 and current_platform.is_device_capability(100):
            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
            # head size 256 and block size 16 is not supported on blackwell.
            assert kv_cache_spec.block_size != 16, (
                "There is a bug in FlashInfer "
                "block_size 16 head size 256 support. Please avoid this combination by "
                "passing --block-size 32 or --block-size 64."
            )

    @classmethod
    @override
    def get_cudagraph_support(
        cls: type["FlashInferMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        has_trtllm_support = can_use_trtllm_attention(
            num_qo_heads=vllm_config.model_config.get_num_attention_heads(
                vllm_config.parallel_config
            ),
            num_kv_heads=kv_cache_spec.num_kv_heads,
        )
        if has_trtllm_support:
            return AttentionCGSupport.UNIFORM_BATCH
        else:
            return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
            if vllm_is_batch_invariant():
                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
            self._workspace_buffer = torch.zeros(
                buffer_size, dtype=torch.uint8, device=self.device
            )
        return self._workspace_buffer

    def _get_prefill_wrapper(
        self,
    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
        if self._prefill_wrapper is None:
            if self.dcp_world_size > 1:
                self._prefill_wrapper = BatchDCPPrefillWrapper(
                    workspace_buffer=self._get_workspace_buffer(),
                )
            else:
                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                    self._get_workspace_buffer(), get_kv_cache_layout()
                )
        assert self._prefill_wrapper is not None
        return self._prefill_wrapper

    def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
        if use_cudagraph:
            decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
        else:
            decode_wrapper = self._decode_wrapper

        if decode_wrapper is None:
            if use_cudagraph:
                paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
                paged_kv_indices = self.paged_kv_indices
                paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
            else:
                paged_kv_indptr = None
                paged_kv_indices = None
                paged_kv_last_page_len = None
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                get_kv_cache_layout(),
                use_cuda_graph=use_cudagraph,
                paged_kv_indptr_buffer=paged_kv_indptr,
                paged_kv_indices_buffer=paged_kv_indices,
                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                # Tensor cores are enabled by default because the perf would be
                # at least as good as cuda cores for all attention ops in latest
                # gpus.
                use_tensor_cores=True,
            )

            # save the decode wrapper
            if use_cudagraph:
                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
            else:
                self._decode_wrapper = decode_wrapper

        return decode_wrapper

    def _get_cascade_wrapper(self):
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), get_kv_cache_layout()
            )
        return self._cascade_wrapper

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashInferMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
        )

        page_size = self.page_size
        max_q_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu

        if self.dcp_world_size > 1:
            if num_prefills > 0:
                qo_indptr_prefill_cpu = (
                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
                )
                query_lens_prefill_cpu = (
                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
                )
                seq_lens_cpu[num_decodes:] = (
                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
                )

            seq_lens_cpu = get_dcp_local_seq_lens(
                seq_lens_cpu,
                self.dcp_world_size,
                self.dcp_rank,
                self.dcp_kv_cache_interleave_size,
            )

        seq_lens_np = seq_lens_cpu.numpy()
        num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # Grab the blocks of the shared prefix from the first request.
            assert common_prefix_len % page_size == 0
            num_common_kv_blocks = common_prefix_len // page_size

            # Create CPU versions directly for cascade (no GPU versions needed)
            shared_qo_indptr_cpu = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indptr_cpu = torch.tensor(
                [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
            )
            shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
            shared_kv_last_page_len_cpu = torch.tensor(
                [page_size], dtype=torch.int32, device="cpu"
            )

            # Remove the blocks of the shared prefix from all requests.
            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
            num_blocks_np -= num_common_kv_blocks
        else:
            shared_qo_indptr_cpu = None
            shared_kv_page_indptr_cpu = None
            shared_kv_page_indices_cpu = None
            shared_kv_last_page_len_cpu = None

        # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
        np.cumsum(
            num_blocks_np,
            dtype=np.int32,
            out=self.paged_kv_indptr_np[1 : num_reqs + 1],
        )
        # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
        # after this line (e.g., for cuda graphs), we need to copy the data to
        # self.paged_kv_indptr_buffer to avoid race condition.
        self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
            : num_reqs + 1
        ]
        paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
        paged_kv_indptr.copy_(
            self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
        )

        # write self.paged_kv_indices inplace
        num_actual_pages = self.paged_kv_indptr_np[num_reqs]
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
        _copy_page_indices_kernel[(num_reqs,)](
            paged_kv_indices,
            block_table_tensor,
            block_table_tensor.stride(0),
            paged_kv_indptr,
            BLOCK_SIZE=1024,
        )

        # write self.paged_kv_last_page_len_cpu inplace
        paged_kv_last_page_len_np = seq_lens_np % page_size
        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
            (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
            page_size,
            paged_kv_last_page_len_np,
        )

        uses_spec_reorder = self.reorder_batch_threshold > 1
        prefill_use_trtllm = use_trtllm_attention(
            self.num_qo_heads,
            self.num_kv_heads,
            num_prefill_tokens,
            max_seq_len,
            self.dcp_world_size,
            self.cache_dtype,
            self.q_data_type,
            is_prefill=True,
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )

        if not (prefill_use_trtllm and decode_use_trtllm):
            if self.has_sinks:
                raise NotImplementedError(
                    "FlashInfer backend currently does not support attention "
                    "sinks, please use trtllm on blackwell or flash attention "
                    "on earlier GPUs."
                )

            if not self.global_hyperparameters.has_same_window_lefts:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True"
                )

            assert self.global_hyperparameters.has_same_all_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values for the following "
                "hyperparameters: `window_left`, `logits_soft_cap`, "
                "`sm_scale`."
            )

            # The q quantization is not supported for non-trtllm attention,
            # fall back to model dtype.
            self.q_data_type = self.model_config.dtype

        attn_metadata = FlashInferMetadata(
            num_actual_tokens=num_actual_tokens,
            q_data_type=self.q_data_type,
            slot_mapping=common_attn_metadata.slot_mapping,
            max_q_len=max_q_len,
            max_q_len_prefill=max_q_len,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table_tensor=block_table_tensor,
            prefill_use_trtllm=prefill_use_trtllm,
            decode_use_trtllm=decode_use_trtllm,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
        )

        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

        if attn_metadata.use_cascade:
            attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
            attn_metadata.cascade_wrapper.plan(
                [shared_qo_indptr_cpu, qo_indptr_cpu],
                [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
                [shared_kv_page_indices_cpu, paged_kv_indices],
                [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                causal=True,
                sm_scale=self.sm_scale,
                window_left=self.window_left,
                logits_soft_cap=self.logits_soft_cap,
                q_data_type=self.q_data_type,
                kv_data_type=self.kv_cache_dtype,
            )
        else:
            # Regular attention (common case).
            # Decodes are at the front and prefills are at the back.
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
                # Decodes are first so prefills start after the last decode
                prefill_start = num_decodes
                attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
                assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
                assert (
                    paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
                )
                # Since prefill_wrapper.run() will be called with
                # query[num_decode_tokens:] we need to adjust the qo_indptr
                # to be relative to the start of the prefill queries.
                qo_indptr_cpu = (
                    qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
                )
                paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]

                # Recompute max_q_len for the slice of requests we are using
                # for prefills. This can be different from max_q_len when
                # we have a non-uniform batch with some short decodes offloaded
                # to the prefill pathway
                query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
                attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

                if not attn_metadata.prefill_use_trtllm:
                    if self.dcp_world_size > 1:
                        assert isinstance(
                            attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
                        )
                        attn_metadata.prefill_wrapper.plan(
                            qo_indptr_cpu=qo_indptr_cpu,
                            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
                            paged_kv_indices=paged_kv_indices,
                            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
                            prefill_start=prefill_start,
                            page_size=self.page_size,
                            num_qo_heads=self.num_qo_heads,
                            dcp_world_size=self.dcp_world_size,
                            num_kv_heads=self.num_kv_heads,
                            head_dim=self.head_dim,
                            sm_scale=self.sm_scale,
                            window_left=self.window_left,
                            logits_soft_cap=self.logits_soft_cap,
                            q_data_type=self.q_data_type,
                            kv_cache_dtype=self.kv_cache_dtype,
                            prefill_fixed_split_size=self.prefill_fixed_split_size,
                            disable_split_kv=self.disable_split_kv,
                        )
                    else:
                        assert isinstance(
                            attn_metadata.prefill_wrapper,
                            BatchPrefillWithPagedKVCacheWrapper,
                        )
                        attn_metadata.prefill_wrapper.plan(
                            qo_indptr_cpu,
                            paged_kv_indptr_cpu,
                            paged_kv_indices,
                            paged_kv_last_page_len_cpu[prefill_start:],
                            self.num_qo_heads,
                            self.num_kv_heads,
                            self.head_dim,
                            self.page_size,
                            causal=True,
                            sm_scale=self.sm_scale,
                            window_left=self.window_left,
                            logits_soft_cap=self.logits_soft_cap,
                            q_data_type=self.q_data_type,
                            kv_data_type=self.kv_cache_dtype,
                            fixed_split_size=self.prefill_fixed_split_size,
                            disable_split_kv=self.disable_split_kv,
                        )
                else:
                    attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                        self.device, non_blocking=True
                    )
                    attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                        self.device, non_blocking=True
                    )

            if num_decodes > 0:
                pure_decode = num_prefills == 0
                # possible required padding for cudagraph replay
                use_cudagraph = (
                    self.enable_cuda_graph
                    and pure_decode
                    and num_decode_tokens <= self._decode_cudagraph_max_bs
                )
                if use_cudagraph:
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
                        num_decode_tokens
                    )
                    # Carefully fill the padding region with reasonable values
                    # on the CPU side.
                    # Make sure paged_kv_indptr_cpu is not decreasing
                    self.paged_kv_indptr_cpu[
                        1 + num_decodes : 1 + num_input_tokens
                    ].fill_(paged_kv_indptr_cpu[-1])
                    # Fill the remaining paged_kv_last_page_len_cpu with 1.
                    # This is because flashinfer treats 0 as a full page
                    # instead of empty.
                    self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
                        1
                    )

                else:
                    num_input_tokens = num_decode_tokens

                attn_metadata.decode_wrapper = self._get_decode_wrapper(
                    num_input_tokens, use_cudagraph
                )
                if not attn_metadata.decode_use_trtllm:
                    # Use the persistent buffer with padding length,
                    # instead of the same address but chunked version
                    # in attn_metadata when using cudagraph.
                    fast_plan_decode(
                        attn_metadata.decode_wrapper,
                        self.paged_kv_indptr_cpu[: num_input_tokens + 1],
                        paged_kv_indices,
                        self.paged_kv_last_page_len_cpu[:num_input_tokens],
                        seq_lens_cpu[:num_input_tokens],
                        self.num_qo_heads * self.dcp_world_size,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        # Disable flashinfer's pos encoding and use vllm's rope.
                        pos_encoding_mode="NONE",
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        fixed_split_size=self.decode_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
        # TODO: Cascade attention doesn't work, disable it for now
        # return use_cascade_attention(*args, **kwargs)
        return False
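
The core of build() is the paged-KV bookkeeping: turning per-request sequence lengths into block counts, a cumulative indptr array, and the fill level of each request's last page. A small NumPy sketch with hypothetical sequence lengths and page size:

import numpy as np

page_size = 16
seq_lens = np.array([1, 17, 32], dtype=np.int32)      # tokens currently in the KV cache
num_blocks = (seq_lens + page_size - 1) // page_size   # -> [1, 2, 2]

paged_kv_indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
np.cumsum(num_blocks, out=paged_kv_indptr[1:])         # -> [0, 1, 3, 5]

# A remainder of 0 for a non-empty sequence means the last page is full.
last_page_len = seq_lens % page_size
last_page_len = np.where((last_page_len == 0) & (seq_lens != 0),
                         page_size, last_page_len)     # -> [1, 1, 16]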

_cascade_wrapper instance-attribute

_cascade_wrapper = None

_decode_cudagraph_max_bs instance-attribute

_decode_cudagraph_max_bs = min(
    (1 + num_spec_tokens) * max_num_reqs,
    max_cudagraph_capture_size,
)

_decode_wrapper instance-attribute

_decode_wrapper = None

_decode_wrappers_cudagraph instance-attribute

_decode_wrappers_cudagraph: dict[
    int, BatchDecodeWithPagedKVCacheWrapper
] = {}

_prefill_wrapper instance-attribute

_prefill_wrapper: (
    BatchPrefillWithPagedKVCacheWrapper
    | BatchDCPPrefillWrapper
    | None
) = None

_workspace_buffer instance-attribute

_workspace_buffer = None

cache_config instance-attribute

cache_config = cache_config

cache_dtype instance-attribute

cache_dtype = cache_dtype

compilation_config instance-attribute

compilation_config = compilation_config

dcp_kv_cache_interleave_size instance-attribute

dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size

dcp_rank instance-attribute

dcp_rank = rank_in_group

dcp_world_size instance-attribute

dcp_world_size = world_size

decode_fixed_split_size instance-attribute

decode_fixed_split_size = 2048

disable_split_kv instance-attribute

disable_split_kv = True

enable_cuda_graph instance-attribute

enable_cuda_graph = decode_mode() == FULL

global_hyperparameters instance-attribute

global_hyperparameters = infer_global_hyperparameters(
    get_per_layer_parameters(
        vllm_config, layer_names, FlashInferImpl
    )
)

has_sinks instance-attribute

has_sinks = has_sinks

head_dim instance-attribute

head_dim = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = get_fp8_dtype_for_flashinfer(cache_dtype)

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

model_config instance-attribute

model_config = model_config

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_qo_heads instance-attribute

num_qo_heads = (
    get_num_attention_heads(parallel_config)
    * dcp_world_size
)

page_size instance-attribute

page_size = block_size

paged_kv_indices instance-attribute

paged_kv_indices = zeros(
    max_num_pages, dtype=int32, device=device
)

paged_kv_indices_cpu instance-attribute

paged_kv_indices_cpu = zeros(
    max_num_pages,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

paged_kv_indptr instance-attribute

paged_kv_indptr = zeros(
    max_num_reqs + 1, dtype=int32, device=device
)

paged_kv_indptr_buffer instance-attribute

paged_kv_indptr_buffer = zeros_like(
    paged_kv_indptr_cpu, pin_memory=pin_memory
)

paged_kv_indptr_cpu instance-attribute

paged_kv_indptr_cpu = zeros(
    max_num_reqs + 1,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

paged_kv_indptr_np instance-attribute

paged_kv_indptr_np = numpy()

paged_kv_last_page_len instance-attribute

paged_kv_last_page_len = zeros(
    max_num_reqs, dtype=int32, device=device
)

paged_kv_last_page_len_cpu instance-attribute

paged_kv_last_page_len_cpu = zeros(
    max_num_reqs,
    dtype=int32,
    device="cpu",
    pin_memory=pin_memory,
)

paged_kv_last_page_len_np instance-attribute

paged_kv_last_page_len_np = numpy()

prefill_fixed_split_size instance-attribute

prefill_fixed_split_size = 4096

q_data_type instance-attribute

q_data_type = kv_cache_dtype

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = 1

sm_scale instance-attribute

sm_scale = sm_scale

use_trtllm_decode_attention instance-attribute

use_trtllm_decode_attention = can_use_trtllm

window_left instance-attribute

window_left = window_left

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/flashinfer.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)
    self.cache_config = vllm_config.cache_config
    self.model_config = vllm_config.model_config
    self._workspace_buffer = None
    self._prefill_wrapper: (
        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
    ) = None  # Wrapper for prefill/append
    self._decode_wrapper = None  # Wrapper for decode (general shape)

    if vllm_is_batch_invariant():
        self.decode_fixed_split_size = 2048
        self.prefill_fixed_split_size = 4096
        self.disable_split_kv = True
    else:
        self.decode_fixed_split_size = -1
        self.prefill_fixed_split_size = -1
        self.disable_split_kv = False

    self.compilation_config = vllm_config.compilation_config
    max_num_pages_per_req = cdiv(
        self.model_config.max_model_len, self.kv_cache_spec.block_size
    )
    max_num_reqs = vllm_config.scheduler_config.max_num_seqs
    max_num_pages = max_num_reqs * max_num_pages_per_req
    speculative_config = vllm_config.speculative_config
    num_spec_tokens = (
        speculative_config.num_speculative_tokens
        if speculative_config is not None
        else 0
    )
    self.enable_cuda_graph = (
        self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
    )
    if self.enable_cuda_graph:
        # For full cudagraph capture, one `decode_wrapper` for each batch
        # size is needed for FlashInfer.
        self._decode_wrappers_cudagraph: dict[
            int, BatchDecodeWithPagedKVCacheWrapper
        ] = {}
        self._decode_cudagraph_max_bs = min(
            (1 + num_spec_tokens) * max_num_reqs,
            self.compilation_config.max_cudagraph_capture_size,
        )

    try:
        self.dcp_world_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group
        self.dcp_kv_cache_interleave_size = (
            vllm_config.parallel_config.dcp_kv_cache_interleave_size
        )
    except AssertionError:
        # DCP might not be initialized in testing
        self.dcp_world_size = 1
        self.dcp_rank = 0
        self.dcp_kv_cache_interleave_size = 1

    self.num_qo_heads = (
        self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
        * self.dcp_world_size
    )

    self.num_kv_heads = self.kv_cache_spec.num_kv_heads
    self.head_dim = self.kv_cache_spec.head_size
    self.page_size = self.kv_cache_spec.block_size

    self.cache_dtype = self.cache_config.cache_dtype
    if self.cache_dtype.startswith("fp8"):
        self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
            self.cache_dtype
        )
    else:
        assert self.kv_cache_spec.dtype == self.model_config.dtype
        self.kv_cache_dtype = self.kv_cache_spec.dtype

    # Use model dtype as q dtype when TRTLLM attn is not supported, or
    # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
    # use fp8 q if kv cache is fp8, and will fall back to model dtype
    # if TRTLLM attention kernel is not used when building attn metadata
    can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
    if can_use_trtllm and not flashinfer_disable_q_quantization():
        self.q_data_type = self.kv_cache_dtype
    else:
        self.q_data_type = self.model_config.dtype

    # Prefer TRTLLM attention for decoding in all cases.
    # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
    self.use_trtllm_decode_attention = can_use_trtllm
    self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)

    self._cascade_wrapper = None  # Wrapper for cascade attention

    # Global hyperparameters shared by all attention layers
    # TODO: discard this for trtllm-gen backend
    self.global_hyperparameters = infer_global_hyperparameters(
        get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)
    )
    self.sm_scale = self.global_hyperparameters.sm_scale
    self.window_left = self.global_hyperparameters.window_left
    self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
    self.has_sinks = self.global_hyperparameters.has_sinks
    if self.has_sinks and not can_use_trtllm:
        raise NotImplementedError(
            "FlashInfer backend currently does not support attention "
            "sinks, please use trtllm on blackwell or flash attention on "
            "earlier GPUs."
        )
    # Preparing persistent buffers (device-side)
    self.paged_kv_indptr = torch.zeros(
        max_num_reqs + 1, dtype=torch.int32, device=self.device
    )
    self.paged_kv_indices = torch.zeros(
        max_num_pages,  # max num pages possible
        dtype=torch.int32,
        device=self.device,
    )
    self.paged_kv_last_page_len = torch.zeros(
        max_num_reqs, dtype=torch.int32, device=self.device
    )
    # host-side buffer
    pin_memory = is_pin_memory_available()
    self.paged_kv_indptr_cpu = torch.zeros(
        max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
    )
    self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
    self.paged_kv_indptr_buffer = torch.zeros_like(
        self.paged_kv_indptr_cpu, pin_memory=pin_memory
    )
    self.paged_kv_indices_cpu = torch.zeros(
        max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
    )
    self.paged_kv_last_page_len_cpu = torch.zeros(
        max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
    )
    self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()

    if self.head_dim == 256 and current_platform.is_device_capability(100):
        # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
        # head size 256 and block size 16 is not supported on blackwell.
        assert kv_cache_spec.block_size != 16, (
            "There is a bug in FlashInfer "
            "block_size 16 head size 256 support. Please avoid this combination by "
            "passing --block-size 32 or --block-size 64."
        )
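
A small sketch of the query-dtype decision made in __init__, with hypothetical inputs; the real code consults can_use_trtllm_attention() and flashinfer_disable_q_quantization().

import torch

def pick_q_dtype(can_use_trtllm: bool, disable_q_quant: bool,
                 kv_cache_dtype: torch.dtype, model_dtype: torch.dtype) -> torch.dtype:
    # Quantized (fp8) queries are only attempted on the TRTLLM path; otherwise
    # the builder keeps the model dtype (and may still fall back later in build()).
    if can_use_trtllm and not disable_q_quant:
        return kv_cache_dtype
    return model_dtype

assert pick_q_dtype(True, False, torch.float8_e4m3fn, torch.bfloat16) is torch.float8_e4m3fn
assert pick_q_dtype(False, False, torch.float8_e4m3fn, torch.bfloat16) is torch.bfloat16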

_get_cascade_wrapper

_get_cascade_wrapper()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_cascade_wrapper(self):
    if self._cascade_wrapper is None:
        self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
            2, self._get_workspace_buffer(), get_kv_cache_layout()
        )
    return self._cascade_wrapper

_get_decode_wrapper

_get_decode_wrapper(
    batch_size: int, use_cudagraph: bool = False
)
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
    if use_cudagraph:
        decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None)
    else:
        decode_wrapper = self._decode_wrapper

    if decode_wrapper is None:
        if use_cudagraph:
            paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1]
            paged_kv_indices = self.paged_kv_indices
            paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size]
        else:
            paged_kv_indptr = None
            paged_kv_indices = None
            paged_kv_last_page_len = None
        decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self._get_workspace_buffer(),
            get_kv_cache_layout(),
            use_cuda_graph=use_cudagraph,
            paged_kv_indptr_buffer=paged_kv_indptr,
            paged_kv_indices_buffer=paged_kv_indices,
            paged_kv_last_page_len_buffer=paged_kv_last_page_len,
            # Tensor cores are enabled by default because the perf would be
            # at least as good as cuda cores for all attention ops in latest
            # gpus.
            use_tensor_cores=True,
        )

        # save the decode wrapper
        if use_cudagraph:
            self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
        else:
            self._decode_wrapper = decode_wrapper

    return decode_wrapper

_get_prefill_wrapper

_get_prefill_wrapper() -> (
    BatchPrefillWithPagedKVCacheWrapper
    | BatchDCPPrefillWrapper
)
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_prefill_wrapper(
    self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
    if self._prefill_wrapper is None:
        if self.dcp_world_size > 1:
            self._prefill_wrapper = BatchDCPPrefillWrapper(
                workspace_buffer=self._get_workspace_buffer(),
            )
        else:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), get_kv_cache_layout()
            )
    assert self._prefill_wrapper is not None
    return self._prefill_wrapper

_get_workspace_buffer

_get_workspace_buffer()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_workspace_buffer(self):
    if self._workspace_buffer is None:
        buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
        if vllm_is_batch_invariant():
            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
        self._workspace_buffer = torch.zeros(
            buffer_size, dtype=torch.uint8, device=self.device
        )
    return self._workspace_buffer
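
A sketch of the workspace sizing logic, where the default size is a hypothetical stand-in for envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE and the batch-invariant size mirrors FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT.

import torch

default_size = 128 * 1024 * 1024           # hypothetical env-configured default
batch_invariant_size = 2048 * 1024 * 1024  # batch-invariant workspace size
batch_invariant = False

buffer_size = batch_invariant_size if batch_invariant else default_size
# One uint8 element per byte; allocated once and reused by all wrappers.
workspace = torch.zeros(buffer_size, dtype=torch.uint8)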

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashInferMetadata
Source code in vllm/v1/attention/backends/flashinfer.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> FlashInferMetadata:
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold,
            require_uniform=True,
        )
    )

    page_size = self.page_size
    max_q_len = common_attn_metadata.max_query_len
    max_seq_len = common_attn_metadata.max_seq_len
    seq_lens = common_attn_metadata.seq_lens
    seq_lens_cpu = common_attn_metadata.seq_lens_cpu
    block_table_tensor = common_attn_metadata.block_table_tensor
    qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu

    if self.dcp_world_size > 1:
        if num_prefills > 0:
            qo_indptr_prefill_cpu = (
                qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
            )
            query_lens_prefill_cpu = (
                qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
            )
            seq_lens_cpu[num_decodes:] = (
                seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
            )

        seq_lens_cpu = get_dcp_local_seq_lens(
            seq_lens_cpu,
            self.dcp_world_size,
            self.dcp_rank,
            self.dcp_kv_cache_interleave_size,
        )

    seq_lens_np = seq_lens_cpu.numpy()
    num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

    use_cascade = common_prefix_len > 0
    if use_cascade:
        # Grab the blocks of the shared prefix from the first request.
        assert common_prefix_len % page_size == 0
        num_common_kv_blocks = common_prefix_len // page_size

        # Create CPU versions directly for cascade (no GPU versions needed)
        shared_qo_indptr_cpu = torch.tensor(
            [0, num_actual_tokens], dtype=torch.int32, device="cpu"
        )
        shared_kv_page_indptr_cpu = torch.tensor(
            [0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
        )
        shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
        shared_kv_last_page_len_cpu = torch.tensor(
            [page_size], dtype=torch.int32, device="cpu"
        )

        # Remove the blocks of the shared prefix from all requests.
        block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
        num_blocks_np -= num_common_kv_blocks
    else:
        shared_qo_indptr_cpu = None
        shared_kv_page_indptr_cpu = None
        shared_kv_page_indices_cpu = None
        shared_kv_last_page_len_cpu = None

    # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
    np.cumsum(
        num_blocks_np,
        dtype=np.int32,
        out=self.paged_kv_indptr_np[1 : num_reqs + 1],
    )
    # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
    # after this line (e.g., for cuda graphs), we need to copy the data to
    # self.paged_kv_indptr_buffer to avoid race condition.
    self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
        : num_reqs + 1
    ]
    paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
    paged_kv_indptr.copy_(
        self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
    )

    # write self.paged_kv_indices inplace
    num_actual_pages = self.paged_kv_indptr_np[num_reqs]
    paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
    _copy_page_indices_kernel[(num_reqs,)](
        paged_kv_indices,
        block_table_tensor,
        block_table_tensor.stride(0),
        paged_kv_indptr,
        BLOCK_SIZE=1024,
    )

    # write self.paged_kv_last_page_len_cpu inplace
    paged_kv_last_page_len_np = seq_lens_np % page_size
    self.paged_kv_last_page_len_np[:num_reqs] = np.where(
        (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
        page_size,
        paged_kv_last_page_len_np,
    )

    uses_spec_reorder = self.reorder_batch_threshold > 1
    prefill_use_trtllm = use_trtllm_attention(
        self.num_qo_heads,
        self.num_kv_heads,
        num_prefill_tokens,
        max_seq_len,
        self.dcp_world_size,
        self.cache_dtype,
        self.q_data_type,
        is_prefill=True,
        has_sinks=self.has_sinks,
        has_spec=uses_spec_reorder,
    )
    decode_use_trtllm = (
        self.use_trtllm_decode_attention and self.dcp_world_size <= 1
    )

    if not (prefill_use_trtllm and decode_use_trtllm):
        if self.has_sinks:
            raise NotImplementedError(
                "FlashInfer backend currently does not support attention "
                "sinks, please use trtllm on blackwell or flash attention "
                "on earlier GPUs."
            )

        if not self.global_hyperparameters.has_same_window_lefts:
            raise ValueError(
                "Window left is not the same for all layers. "
                "One potential fix is to set disable_sliding_window=True"
            )

        assert self.global_hyperparameters.has_same_all_params, (
            "FlashInfer backend currently only supports models in which "
            "all layers share the same values for the following "
            "hyperparameters: `window_left`, `logits_soft_cap`, "
            "`sm_scale`."
        )

        # The q quantization is not supported for non-trtllm attention,
        # fall back to model dtype.
        self.q_data_type = self.model_config.dtype

    attn_metadata = FlashInferMetadata(
        num_actual_tokens=num_actual_tokens,
        q_data_type=self.q_data_type,
        slot_mapping=common_attn_metadata.slot_mapping,
        max_q_len=max_q_len,
        max_q_len_prefill=max_q_len,
        max_seq_len=max_seq_len,
        seq_lens=seq_lens,
        block_table_tensor=block_table_tensor,
        prefill_use_trtllm=prefill_use_trtllm,
        decode_use_trtllm=decode_use_trtllm,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        use_cascade=use_cascade,
    )

    paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
    paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]

    if attn_metadata.use_cascade:
        attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
        attn_metadata.cascade_wrapper.plan(
            [shared_qo_indptr_cpu, qo_indptr_cpu],
            [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
            [shared_kv_page_indices_cpu, paged_kv_indices],
            [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            self.page_size,
            causal=True,
            sm_scale=self.sm_scale,
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            q_data_type=self.q_data_type,
            kv_data_type=self.kv_cache_dtype,
        )
    else:
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back.
        num_prefills = attn_metadata.num_prefills
        num_decodes = attn_metadata.num_decodes
        if num_prefills > 0:
            # Decodes are first so prefills start after the last decode
            prefill_start = num_decodes
            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
            assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
            assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
            assert (
                paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills
            )
            # Since prefill_wrapper.run() will be called with
            # query[num_decode_tokens:] we need to adjust the qo_indptr
            # to be relative to the start of the prefill queries.
            qo_indptr_cpu = (
                qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
            )
            paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]

            # Recompute max_q_len for the slice of requests we are using
            # for prefills. This can be different from max_q_len when
            # we have a non-uniform batch with some short decodes offloaded
            # to the prefill pathway
            query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
            attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

            if not attn_metadata.prefill_use_trtllm:
                if self.dcp_world_size > 1:
                    assert isinstance(
                        attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
                    )
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu=qo_indptr_cpu,
                        paged_kv_indptr_cpu=paged_kv_indptr_cpu,
                        paged_kv_indices=paged_kv_indices,
                        paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
                        prefill_start=prefill_start,
                        page_size=self.page_size,
                        num_qo_heads=self.num_qo_heads,
                        dcp_world_size=self.dcp_world_size,
                        num_kv_heads=self.num_kv_heads,
                        head_dim=self.head_dim,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_cache_dtype=self.kv_cache_dtype,
                        prefill_fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
                else:
                    assert isinstance(
                        attn_metadata.prefill_wrapper,
                        BatchPrefillWithPagedKVCacheWrapper,
                    )
                    attn_metadata.prefill_wrapper.plan(
                        qo_indptr_cpu,
                        paged_kv_indptr_cpu,
                        paged_kv_indices,
                        paged_kv_last_page_len_cpu[prefill_start:],
                        self.num_qo_heads,
                        self.num_kv_heads,
                        self.head_dim,
                        self.page_size,
                        causal=True,
                        sm_scale=self.sm_scale,
                        window_left=self.window_left,
                        logits_soft_cap=self.logits_soft_cap,
                        q_data_type=self.q_data_type,
                        kv_data_type=self.kv_cache_dtype,
                        fixed_split_size=self.prefill_fixed_split_size,
                        disable_split_kv=self.disable_split_kv,
                    )
            else:
                attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                    self.device, non_blocking=True
                )
                attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
                    self.device, non_blocking=True
                )

        if num_decodes > 0:
            pure_decode = num_prefills == 0
            # Padding may be required for cudagraph replay.
            use_cudagraph = (
                self.enable_cuda_graph
                and pure_decode
                and num_decode_tokens <= self._decode_cudagraph_max_bs
            )
            if use_cudagraph:
                num_input_tokens = self.vllm_config.pad_for_cudagraph(
                    num_decode_tokens
                )
                # Carefully fill the padding region with reasonable values
                # on the CPU.
                # Keep paged_kv_indptr_cpu non-decreasing.
                self.paged_kv_indptr_cpu[
                    1 + num_decodes : 1 + num_input_tokens
                ].fill_(paged_kv_indptr_cpu[-1])
                # Fill the remaining paged_kv_last_page_len_cpu with 1.
                # This is because flashinfer treats 0 as a full page
                # instead of empty.
                self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
                    1
                )

            else:
                num_input_tokens = num_decode_tokens

            attn_metadata.decode_wrapper = self._get_decode_wrapper(
                num_input_tokens, use_cudagraph
            )
            if not attn_metadata.decode_use_trtllm:
                # Use the persistent buffers with the padded length, instead
                # of the chunked views at the same addresses in attn_metadata,
                # when using cudagraph.
                fast_plan_decode(
                    attn_metadata.decode_wrapper,
                    self.paged_kv_indptr_cpu[: num_input_tokens + 1],
                    paged_kv_indices,
                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                    seq_lens_cpu[:num_input_tokens],
                    self.num_qo_heads * self.dcp_world_size,
                    self.num_kv_heads,
                    self.head_dim,
                    self.page_size,
                    # Disable flashinfer's pos encoding and use vllm's rope.
                    pos_encoding_mode="NONE",
                    sm_scale=self.sm_scale,
                    window_left=self.window_left,
                    logits_soft_cap=self.logits_soft_cap,
                    q_data_type=self.q_data_type,
                    kv_data_type=self.kv_cache_dtype,
                    fixed_split_size=self.decode_fixed_split_size,
                    disable_split_kv=self.disable_split_kv,
                )
    return attn_metadata
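
To make the index bookkeeping above easier to follow, here is a minimal, self-contained sketch of the prefill qo_indptr re-basing and the cudagraph padding of the decode buffers. All tensor values and the padded batch size are made up for illustration; only the slicing pattern mirrors the builder.

# Illustrative sketch only; values are invented and nothing here is vLLM code.
import torch

# Two decodes followed by two prefills with 3 and 4 query tokens each.
qo_indptr_cpu = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int32)
num_decodes, num_prefills = 2, 2
prefill_start = num_decodes

# Re-base so the prefill slice starts at 0, matching query[num_decode_tokens:].
prefill_qo_indptr = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
assert prefill_qo_indptr.tolist() == [0, 3, 7]

# Per-prefill query lengths and the recomputed max_q_len for this slice.
query_lens_prefill = prefill_qo_indptr[1:] - prefill_qo_indptr[:-1]
assert int(query_lens_prefill.max()) == 4

# Cudagraph padding: repeat the last indptr value so the buffer stays
# non-decreasing, and use 1 (not 0) for the padded last-page lengths,
# since FlashInfer treats 0 as a full page.
paged_kv_indptr_cpu = torch.tensor([0, 2, 3], dtype=torch.int32)  # 2 decodes
num_input_tokens = 4  # padded batch size for replay (assumed)
padded_indptr = torch.empty(1 + num_input_tokens, dtype=torch.int32)
padded_indptr[: 1 + num_decodes] = paged_kv_indptr_cpu
padded_indptr[1 + num_decodes:] = paged_kv_indptr_cpu[-1]
assert padded_indptr.tolist() == [0, 2, 3, 3, 3]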

get_cudagraph_support classmethod

get_cudagraph_support(
    vllm_config: VllmConfig, kv_cache_spec: AttentionSpec
) -> AttentionCGSupport
Source code in vllm/v1/attention/backends/flashinfer.py
@classmethod
@override
def get_cudagraph_support(
    cls: type["FlashInferMetadataBuilder"],
    vllm_config: VllmConfig,
    kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
    has_trtllm_support = can_use_trtllm_attention(
        num_qo_heads=vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        ),
        num_kv_heads=kv_cache_spec.num_kv_heads,
    )
    if has_trtllm_support:
        return AttentionCGSupport.UNIFORM_BATCH
    else:
        return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

use_cascade_attention

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/flashinfer.py
def use_cascade_attention(self, *args, **kwargs) -> bool:
    if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
        # TODO: The cascade wrapper currently does not support setting
        # kv cache dtype to something different from query dtype.
        return False
    # TODO: Cascade attention doesn't work, disable it for now
    # return use_cascade_attention(*args, **kwargs)
    return False

_copy_page_indices_kernel

_copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/attention/backends/flashinfer.py
@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = block_table + req_idx * block_table_stride
    start_idx = tl.load(cu_num_blocks + req_idx)
    end_idx = tl.load(cu_num_blocks + req_idx + 1)
    num_blocks = end_idx - start_idx

    offset = tl.arange(0, BLOCK_SIZE)
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
        tl.store(
            page_indices + start_idx + i + offset,
            block_ids,
            mask=i + offset < num_blocks,
        )
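
As a hedged illustration of how this kernel is meant to be driven (the shapes, block-table contents, and BLOCK_SIZE below are assumptions, not values taken from vLLM), a launch that flattens a padded block table into a contiguous page-index buffer could look like this:

# Hypothetical driver; assumes a CUDA device, Triton, and that the private
# kernel above can be imported from this module.
import torch

from vllm.v1.attention.backends.flashinfer import _copy_page_indices_kernel

num_reqs = 3
# Number of KV-cache blocks actually used by each request (made-up values).
num_blocks_per_req = torch.tensor([2, 5, 3], dtype=torch.int32, device="cuda")
cu_num_blocks = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cuda")
cu_num_blocks[1:] = torch.cumsum(num_blocks_per_req, dim=0)

# Padded block table, one row per request; entries beyond the per-request
# block count are ignored by the kernel.
block_table = torch.randint(0, 1024, (num_reqs, 8), dtype=torch.int32, device="cuda")
page_indices = torch.empty(int(cu_num_blocks[-1]), dtype=torch.int32, device="cuda")

# One program per request; each program copies its row in BLOCK_SIZE chunks.
_copy_page_indices_kernel[(num_reqs,)](
    page_indices,
    block_table,
    block_table.stride(0),
    cu_num_blocks,
    BLOCK_SIZE=1024,
)
# page_indices now holds the first num_blocks_per_req[i] block ids of each
# row, packed back to back: 2 + 5 + 3 = 10 entries in total.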

_get_trtllm_gen_workspace_buffer

_get_trtllm_gen_workspace_buffer()
Source code in vllm/v1/attention/backends/flashinfer.py
def _get_trtllm_gen_workspace_buffer():
    global trtllm_gen_workspace_buffer
    if trtllm_gen_workspace_buffer is None:
        trtllm_gen_workspace_buffer = torch.zeros(
            envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda"
        )
    return trtllm_gen_workspace_buffer

_trtllm_prefill_attn_kvfp8_dequant

_trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: constexpr,
    KV_CACHE_STRIDE: constexpr,
)
Source code in vllm/v1/attention/backends/flashinfer.py
@triton.jit
def _trtllm_prefill_attn_kvfp8_dequant(
    kv_cache_ptr,
    block_tables_prefill_ptr,
    block_table_stride,
    mock_kv_cache_ptr,
    k_scale_ptr,
    v_scale_ptr,
    K_CACHE_STRIDE: tl.constexpr,
    KV_CACHE_STRIDE: tl.constexpr,
):
    batch_idx = tl.program_id(0).to(tl.int64)
    mock_block_table_idx = tl.program_id(1).to(tl.int64)
    orig_page_num = tl.load(
        block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx
    ).to(tl.int64)
    if orig_page_num <= 0:
        return
    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty

    # Dequantize K
    k_scale_val = tl.load(k_scale_ptr)
    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
    mock_cache_offset = (
        batch_idx * block_table_stride + mock_block_table_idx + 1
    ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)

    # Dequantize V
    v_scale_val = tl.load(v_scale_ptr)
    offset = (
        orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
    )
    fp8_vals = tl.load(kv_cache_ptr + offset)
    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
    mock_cache_offset = (
        (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
        + K_CACHE_STRIDE
        + tl.arange(0, K_CACHE_STRIDE)
    )
    dequantized_vals = dequantized_vals.to(dequant_dtype)
    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)

fast_plan_decode

fast_plan_decode(
    self,
    indptr_cpu: Tensor,
    indices: Tensor,
    last_page_len_cpu: Tensor,
    seq_lens_cpu: Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: float | None = None,
    q_data_type: str | dtype | None = "float16",
    kv_data_type: str | dtype | None = None,
    data_type: str | dtype | None = None,
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
    non_blocking: bool = True,
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
) -> None

A faster version of BatchDecodeWithPagedKVCacheWrapper::plan, used for cudagraph capture/replay; when cudagraphs are not enabled it falls back to the original plan after the host-side buffers have been passed in.

Modifications for cudagraph:

- only host-to-device copies of the indptr and last_page_len buffers.
- no device-to-device copy of the indices buffer.

Parts of the code take inspiration from the original plan in the FlashInfer repo and from the fast_decode_plan implementation for FlashInfer in the SGLang repo.

Source code in vllm/v1/attention/backends/flashinfer.py
def fast_plan_decode(
    self,  # decode wrapper
    indptr_cpu: torch.Tensor,
    indices: torch.Tensor,
    last_page_len_cpu: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: float | None = None,
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
    data_type: str | torch.dtype | None = None,
    sm_scale: float | None = None,
    rope_scale: float | None = None,
    rope_theta: float | None = None,
    non_blocking: bool = True,
    fixed_split_size: int = -1,
    disable_split_kv: bool = False,
) -> None:
    """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
    cudagraph capture/replay, while the no cudagraph version turns back
    to the original plan.
    using original plan after passing host-side buffers:
    - only host-to-device copy of indptr and last_page_len buffers
    Modifications for cudagraph:
    - only host-to-device copy of indptr and last_page_len buffers.
    - avoid device-to-device copy of indices buffer.

    Part of the code get inspiration from the original plan from FlashInfer repo
    and the implementation of fast_decode_plan for FlashInfer in SGlang repo.
    """
    # Warm up with the original plan on the first call, and always run the
    # original plan for dynamic shapes. For fixed shapes (cudagraph), this
    # warm-up generates the _cached_module for the decode wrapper.
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
        self.plan(
            indptr_cpu,
            indices,
            last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            pos_encoding_mode,
            window_left,
            logits_soft_cap,
            q_data_type,
            kv_data_type,
            data_type,
            sm_scale,
            rope_scale,
            rope_theta,
            non_blocking,
            None,  # block_tables
            None,  # seq_lens
            fixed_split_size,
            disable_split_kv,
        )
        self.vllm_first_call = False
        return

    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"

    batch_size = len(last_page_len_cpu)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type
    q_data_type = (
        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
    )
    kv_data_type = (
        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
    )

    if batch_size != self._fixed_batch_size:
        raise ValueError(
            "The batch size should be fixed in cudagraph mode, the runtime "
            "batch size {} mismatches the batch size set during "
            "initialization {}".format(batch_size, self._fixed_batch_size)
        )
    if len(indices) > len(self._paged_kv_indices_buf):
        raise ValueError(
            "The size of indices should be less than or equal to the allocated buffer"
        )

    # host-to-device copy for the indptr buffer
    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
    # host-to-device copy for the last_page_len buffer
    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)

    qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

    try:
        # Make sure we pass exactly 18 arguments for tensor core version
        self._plan_info = self._cached_module.plan(
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # total_num_rows
            batch_size,
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
            window_left,
            fixed_split_size,
            disable_split_kv,
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
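
As a rough usage sketch of the function above: it is invoked as a free function on an existing FlashInfer decode wrapper, mirroring the builder's call site earlier on this page. The workspace size, head counts, page geometry, and sequence lengths below are placeholder assumptions, not values taken from vLLM.

# Hedged sketch; requires a CUDA device and the flashinfer package.
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from vllm.v1.attention.backends.flashinfer import fast_plan_decode

workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

batch_size, page_size, head_dim = 4, 16, 128
indptr_cpu = torch.arange(batch_size + 1, dtype=torch.int32)          # one page per request
indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")  # page ids
last_page_len_cpu = torch.full((batch_size,), 7, dtype=torch.int32)
seq_lens_cpu = torch.full((batch_size,), 7, dtype=torch.int32)

fast_plan_decode(
    wrapper,
    indptr_cpu,
    indices,
    last_page_len_cpu,
    seq_lens_cpu,
    num_qo_heads=32,
    num_kv_heads=8,
    head_dim=head_dim,
    page_size=page_size,
    pos_encoding_mode="NONE",  # vLLM applies RoPE outside the attention kernel
    sm_scale=head_dim**-0.5,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)
# The first call (or any call without cudagraphs enabled) falls through to the
# wrapper's original plan(); subsequent cudagraph calls only refresh the
# host-side indptr and last_page_len buffers.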

trtllm_prefill_attn_kvfp8_dequant

trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: Tensor,
    block_tables_prefill: Tensor,
    k_scale: Tensor,
    v_scale: Tensor,
    dequant_dtype: dtype,
) -> tuple[Tensor, Tensor]
Source code in vllm/v1/attention/backends/flashinfer.py
def trtllm_prefill_attn_kvfp8_dequant(
    kv_cache: torch.Tensor,
    block_tables_prefill: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    dequant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size, num_of_page_per_token = block_tables_prefill.shape
    s = kv_cache.shape
    assert s[1] == 2
    assert dequant_dtype in (torch.bfloat16, torch.float16)
    k_cache_stride = s[2] * s[3] * s[4]
    kv_cache_stride = k_cache_stride * s[1]
    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
    # mock kv cache contains just the pages needed by this prefill
    mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
    # we simply sequentially index the pages needed by this prefill
    mock_block_table = torch.arange(
        start=1,
        end=batch_size * num_of_page_per_token + 1,
        dtype=torch.int32,
        device=block_tables_prefill.device,
    ).reshape(batch_size, num_of_page_per_token)
    grid = (batch_size, num_of_page_per_token)
    _trtllm_prefill_attn_kvfp8_dequant[grid](
        kv_cache,
        block_tables_prefill,
        num_of_page_per_token,
        mock_kv_cache,
        k_scale,
        v_scale,
        k_cache_stride,
        kv_cache_stride,
    )
    return mock_kv_cache, mock_block_table
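
Finally, a hedged usage sketch of this helper; the cache geometry (num_pages, 2, page_size, num_kv_heads, head_dim), the block-table contents, and the scales are assumptions chosen only to satisfy the shape checks above.

# Hypothetical shapes; requires a CUDA device, Triton, and FP8 tensor support.
import torch

from vllm.v1.attention.backends.flashinfer import trtllm_prefill_attn_kvfp8_dequant

num_pages, page_size, num_kv_heads, head_dim = 64, 16, 8, 128
kv_cache = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)
# Two prefill requests, padded with page id 0 (skipped by the kernel).
block_tables_prefill = torch.tensor(
    [[3, 7, 0], [12, 0, 0]], dtype=torch.int32, device="cuda"
)
k_scale = torch.tensor([0.02], device="cuda")
v_scale = torch.tensor([0.03], device="cuda")

mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
    kv_cache, block_tables_prefill, k_scale, v_scale, torch.bfloat16
)
# mock_kv_cache contains only the pages touched by these prefills, dequantized
# to bfloat16; mock_block_table simply numbers them 1..N row by row.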