vllm.compilation.sequence_parallelism
FirstAllReduceRMSNormPattern
Bases: _SequenceParallelPatternHelper
Source code in vllm/compilation/sequence_parallelism.py
__init__
get_inputs
register
FirstAllReduceRMSNormStaticFP8Pattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
MiddleAllReduceRMSNormPattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
MiddleAllReduceRMSNormStaticFP8Pattern
Bases: _SequenceParallelPatternHelper
__init__
get_inputs
register
SequenceParallelismPass
Bases: VllmPatternMatcherPass
This pass enables sequence parallelism for models. It identifies patterns where an AllReduce operation is followed by an RMSNorm operation (or by an RMSNorm followed by quantization). Each such pattern is replaced with a ReduceScatter, followed by a local RMSNorm/quantization on each rank's shard, and then an AllGather.
The general transformation rewrites
Input -> AllReduce -> RMSNorm -> Output
into
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
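The rewrite is valid because RMSNorm normalizes each token independently, so normalizing a token-wise shard and gathering the shards gives the same result as normalizing the full all-reduced tensor. The sketch below checks this on plain Python lists. It is a toy model under stated assumptions, not vLLM code: the collectives are simulated locally, the RMSNorm has no learned weight, and ReduceScatter is assumed to split along the token dimension with `num_tokens` divisible by the TP size. All function names are hypothetical.

```python
import math

# Hypothetical toy model of tensor-parallel collectives on plain Python lists
# (no real communication). Each rank's tensor has shape [num_tokens][hidden].


def rms_norm(rows, eps=1e-6):
    """RMSNorm without a learned weight: each token (row) is normalized
    independently by its own root mean square."""
    out = []
    for row in rows:
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v / rms for v in row])
    return out


def _sum_partials(partials):
    """Elementwise sum of the per-rank partial tensors."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def all_reduce(partials):
    """Every rank receives the full summed tensor."""
    summed = _sum_partials(partials)
    return [[list(row) for row in summed] for _ in partials]


def reduce_scatter(partials):
    """Sum the partials, then rank r keeps only its contiguous token chunk."""
    tp = len(partials)
    summed = _sum_partials(partials)
    assert len(summed) % tp == 0, "num_tokens must be divisible by TP size"
    chunk = len(summed) // tp
    return [summed[r * chunk:(r + 1) * chunk] for r in range(tp)]


def all_gather(chunks):
    """Concatenate all ranks' chunks along the token dimension."""
    gathered = [row for chunk in chunks for row in chunk]
    return [[list(row) for row in gathered] for _ in chunks]


def allreduce_then_norm(partials):
    # Input -> AllReduce -> RMSNorm -> Output: every rank normalizes all tokens.
    return [rms_norm(full) for full in all_reduce(partials)]


def sequence_parallel(partials):
    # Input -> ReduceScatter -> RMSNorm -> AllGather -> Output:
    # every rank normalizes only its own shard of tokens.
    local = [rms_norm(chunk) for chunk in reduce_scatter(partials)]
    return all_gather(local)


if __name__ == "__main__":
    tp, num_tokens, hidden = 2, 4, 3
    partials = [
        [[float(r + t * hidden + h) for h in range(hidden)] for t in range(num_tokens)]
        for r in range(tp)
    ]
    a = allreduce_then_norm(partials)
    b = sequence_parallel(partials)
    diff = max(abs(x - y) for ra, rb in zip(a, b)
               for row_a, row_b in zip(ra, rb) for x, y in zip(row_a, row_b))
    print("max abs difference:", diff)
```

Between the ReduceScatter and the AllGather each rank holds only `num_tokens / tp` rows, which is what later GEMM + ReduceScatter and AllGather + GEMM fusions can exploit.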
While this pass itself does not directly yield performance improvements, it lays the groundwork for subsequent fusion passes, such as GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can significantly reduce communication overhead and improve overall model performance.
This pass splits the residual tensor across TP ranks, so each rank holds only a 1/tp_size share of it. Because the pattern matcher starts at the end of the graph, each replacement contains a slice that temporarily trims the incoming residual to the expected (reduced) size. After all patterns have been matched, a NoOpEliminationPass removes these slices, which by then have become no-ops.
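A toy sketch of why the temporary slice ends up as a no-op. It assumes a hypothetical list-of-ops representation with shapes attached (not vLLM's FX graph or the real NoOpEliminationPass API), and in this toy the slice keeps rows `[0:end]`. While matching is in progress, the residual coming from a not-yet-rewritten region is still full size, so the slice really shrinks it; once every pattern has been rewritten, the residual already has `num_tokens // tp` rows and the same slice is an identity.

```python
from dataclasses import dataclass

# Hypothetical toy IR, not vLLM's FX graph or its NoOpEliminationPass API.


@dataclass
class Node:
    op: str
    in_shape: tuple
    out_shape: tuple


def slice_rows(shape, end):
    """Output shape of keeping rows [0:end] of a tensor with `shape`."""
    return (min(end, shape[0]),) + tuple(shape[1:])


def eliminate_noop_slices(nodes):
    """Drop slice nodes that do not change the shape (i.e. identities)."""
    return [n for n in nodes if not (n.op == "slice" and n.in_shape == n.out_shape)]


num_tokens, hidden, tp = 8, 16, 2
local_tokens = num_tokens // tp

# Mid-pass: the residual is still full size, so the slice really shrinks it.
# It only makes the shapes line up; which rows it keeps is not meaningful yet.
mid_pass_slice = Node("slice", (num_tokens, hidden),
                      slice_rows((num_tokens, hidden), local_tokens))

# After every pattern is matched, the residual arrives already split, so the
# same slice maps (local_tokens, hidden) to itself and becomes a no-op.
final_slice = Node("slice", (local_tokens, hidden),
                   slice_rows((local_tokens, hidden), local_tokens))

# Ops in execution order (toy: residual edges are omitted).
graph = [
    Node("reduce_scatter", (num_tokens, hidden), (local_tokens, hidden)),
    final_slice,
    Node("rms_norm", (local_tokens, hidden), (local_tokens, hidden)),
    Node("all_gather", (local_tokens, hidden), (num_tokens, hidden)),
]
print([n.op for n in eliminate_noop_slices(graph)])
# prints ['reduce_scatter', 'rms_norm', 'all_gather']
```

Checking shapes rather than values is the design choice here: a slice whose output shape equals its input shape cannot have dropped any rows, so removing it is safe regardless of what the tensor contains.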
Note that an older version of the pass did not need this slice because it operated only on the rms_norm and fused_rms_norm_add custom ops, which did not complain about mismatched shapes during replacement. The current approach keeps the same assumption: correctness is maintained only if all rms_norm operations are split across ranks.
In terms of correctness, this approach is strictly better than the old one: previously, the graph was both semantically and shape-wise incorrect while the pass was running; now it is only semantically incorrect in that intermediate state. Both approaches restore a correct graph once all patterns are matched.
patterns instance-attribute
__init__
__init__(config: VllmConfig)
is_applicable
_SequenceParallelPatternHelper
Helper for sequence parallelism patterns.