malaysia-ai/Flash-Attention3-wheel
Source: Hugging Face. Updated 2025-11-19; indexed 2026-01-03.
Download link:
https://hf-mirror.com/datasets/malaysia-ai/Flash-Attention3-wheel
Resource description:
# Flash-Attention3-wheel
Flash Attention 3 wheels on commit [0e60e39473e8df549a20fb5353760f7a65b30e2d](https://github.com/Dao-AILab/flash-attention/commit/0e60e39473e8df549a20fb5353760f7a65b30e2d).
For a much better alternative, see https://windreamer.github.io/flash-attention3-wheels/.
## Built on H100
Wheels for PyTorch 2.6.0 (CUDA 12.6), 2.7.0 (CUDA 12.6), 2.7.0 (CUDA 12.8), 2.7.1 (CUDA 12.6) and 2.7.1 (CUDA 12.8); requires Python 3.9 or newer.
## Built on GH200 (ARM64)
Wheels for PyTorch 2.7.0 (CUDA 12.8) and 2.7.1 (CUDA 12.8); requires Python 3.9 or newer.
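To check an environment against the build lists above before downloading, a small sketch like the following may help. The `SUPPORTED` table and the `is_supported` helper are hypothetical, transcribed by hand from this README; the architecture keys are assumptions (`x86_64` for the H100 builds, `aarch64` for the GH200 builds, which is what `platform.machine()` reports on 64-bit ARM Linux).

```python
# Hypothetical table transcribed from this README: (PyTorch, CUDA) pairs per architecture.
SUPPORTED = {
    "x86_64": {  # assumed architecture of the H100 builds
        ("2.6.0", "12.6"), ("2.7.0", "12.6"), ("2.7.0", "12.8"),
        ("2.7.1", "12.6"), ("2.7.1", "12.8"),
    },
    "aarch64": {  # assumed architecture name of the GH200 (ARM64) builds
        ("2.7.0", "12.8"), ("2.7.1", "12.8"),
    },
}
MIN_PYTHON = (3, 9)  # the wheels are tagged cp39-abi3


def is_supported(torch_version: str, cuda_version, arch: str,
                 python: tuple = MIN_PYTHON) -> bool:
    """Return True if a prebuilt wheel is listed for this combination."""
    base = torch_version.split("+")[0]  # drop local tags such as "+cu128"
    if tuple(python[:2]) < MIN_PYTHON:
        return False
    return (base, cuda_version) in SUPPORTED.get(arch, set())


print(is_supported("2.7.1+cu128", "12.8", "x86_64"))  # True
print(is_supported("2.6.0", "12.6", "aarch64"))       # False

# Checking the current environment (requires PyTorch):
#   import platform, sys, torch
#   is_supported(torch.__version__, torch.version.cuda, platform.machine(), sys.version_info)
```

`torch.version.cuda` is `None` on CPU-only PyTorch builds, which this sketch simply reports as unsupported.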
## Installation
```bash
# Rename the wheel after downloading, otherwise pip rejects the filename,
# e.g. remove the `-2.7.1-12.8` (PyTorch version, CUDA version) suffix
# Check all available versions and architectures at https://huggingface.co/datasets/malaysia-ai/Flash-Attention3-wheel/tree/main
# The example below is for PyTorch 2.7.1, CUDA 12.8, x86_64
wget https://huggingface.co/datasets/mesolitica/Flash-Attention3-whl/resolve/main/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64-2.7.1-12.8.whl -O flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
pip3 install flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
```
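The rename matters because pip expects the standard wheel filename layout `{distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl`, and the extra PyTorch/CUDA fields break it. A minimal sketch of the rename as a Python function follows; `pip_wheel_name` is a hypothetical helper, and the suffix pattern (`-X.Y.Z-X.Y` right before `.whl`) is an assumption inferred from the example filename above.

```python
import re

# Assumed suffix layout "-<torch X.Y.Z>-<CUDA X.Y>" right before ".whl",
# inferred from the example filename in this README.
_TORCH_CUDA_SUFFIX = re.compile(r"-\d+\.\d+\.\d+-\d+\.\d+(?=\.whl$)")


def pip_wheel_name(downloaded: str) -> str:
    """Hypothetical helper: drop the PyTorch/CUDA suffix from a downloaded wheel name."""
    name = _TORCH_CUDA_SUFFIX.sub("", downloaded)
    if not name.endswith(".whl"):
        raise ValueError(f"not a wheel file: {downloaded}")
    # Standard layout has 5 dash-separated fields, or 6 with the optional build tag
    fields = name[: -len(".whl")].split("-")
    if len(fields) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {name}")
    return name


print(pip_wheel_name("flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64-2.7.1-12.8.whl"))
# flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
```

An already-clean name passes through unchanged, so the helper is safe to call twice.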
## Unit test
```python
import flash_attn_interface  # Flash Attention 3, installed from this wheel
import torch
import random
import numpy as np
import flash_attn  # Flash Attention 2, used as a second reference
import torch.nn.functional as F
def generate_list_sum_n(n, length=5, min_val=5):
    # Random list of `length` integers, each >= min_val, that sums to n
    numbers = [min_val] * length
    remaining = n - min_val * length
    for _ in range(remaining):
        numbers[random.randint(0, length - 1)] += 1
    random.shuffle(numbers)
    return numbers

def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    # Place each per-sequence mask on the diagonal of one large (total, total) mask
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)
    current_pos = 0
    for mask in masks:
        size = mask.size(0)
        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size
    # Convert to an additive mask: 0 where attention is allowed, dtype minimum elsewhere
    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value)
    return inverted_mask.unsqueeze(0)
# Pack 20 causal sequences (each at least 10 tokens long) into one 4096-token row
sequence_length = 4096
query_lens = np.array(generate_list_sum_n(sequence_length, length=20, min_val=10), dtype=np.int64)
masks = [torch.tril(torch.ones(int(m), int(m))) for m in query_lens]
attention_mask = block_diagonal_concat_inverted(*masks).cuda()
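# q, k, v layout: (batch, seqlen, nheads, headdim) = (1, 4096, 128, 128)
q = torch.randn(1, sequence_length, 128, 128, dtype = torch.bfloat16).cuda()
k = torch.randn(1, sequence_length, 128, 128, dtype = torch.bfloat16).cuda()
v = torch.randn(1, sequence_length, 128, 128, dtype = torch.bfloat16).cuda()
# SDPA expects (batch, nheads, seqlen, headdim), hence the transposes;
# attention_mask[None] has shape (1, 1, seqlen, seqlen) and broadcasts over heads
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    query = q.transpose(1, 2),
    key = k.transpose(1, 2),
    value = v.transpose(1, 2),
    attn_mask = attention_mask[None],
)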
# Cumulative sequence offsets [0, len0, len0 + len1, ...] for the varlen kernels
cumsum = [0] + np.cumsum(query_lens).tolist()
cu_seq_lens_q = torch.tensor(cumsum, dtype=torch.int32).cuda()
max_seqlen_q = int(np.max(query_lens))
out_flash2 = flash_attn.flash_attn_varlen_func(
q = q[0],
k = k[0],
v = v[0],
cu_seqlens_q = cu_seq_lens_q,
cu_seqlens_k = cu_seq_lens_q,
max_seqlen_q = max_seqlen_q,
max_seqlen_k = max_seqlen_q,
causal = True
)
out_flash3 = flash_attn_interface.flash_attn_varlen_func(
q = q[0],
k = k[0],
v = v[0],
cu_seqlens_q = cu_seq_lens_q,
cu_seqlens_k = cu_seq_lens_q,
max_seqlen_q = max_seqlen_q,
max_seqlen_k = max_seqlen_q,
causal = True,
)
assert torch.allclose(out_flash3, out_sdpa[0].transpose(0, 1), atol=0.125, rtol=0)
assert torch.allclose(out_flash3, out_flash2, atol=0.125, rtol=0)
```
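The varlen calls in this test replace the dense block-diagonal mask passed to SDPA with `cu_seqlens` offsets plus `causal=True`. The NumPy sketch below (no GPU needed; all helper names are hypothetical) checks that the two descriptions select the same query/key pairs: a token may attend to another only if both fall in the same `[cu_seqlens[i], cu_seqlens[i + 1])` segment and the key position does not exceed the query position.

```python
import numpy as np


def cu_seqlens_from_lengths(lengths):
    """Offsets [0, l0, l0 + l1, ...]; sequence i occupies tokens [cu[i], cu[i + 1])."""
    return np.concatenate([[0], np.cumsum(lengths)]).astype(np.int32)


def dense_block_causal_mask(lengths):
    """Boolean mask (True = may attend), laid out like the test's block-diagonal mask."""
    total = int(np.sum(lengths))
    mask = np.zeros((total, total), dtype=bool)
    start = 0
    for n in lengths:
        n = int(n)
        mask[start:start + n, start:start + n] = np.tril(np.ones((n, n), dtype=bool))
        start += n
    return mask


def mask_from_cu_seqlens(cu_seqlens):
    """The same mask recovered from offsets alone: same segment and key <= query."""
    total = int(cu_seqlens[-1])
    pos = np.arange(total)
    # Segment index of every token: the last offset that is <= its position
    seg = np.searchsorted(cu_seqlens, pos, side="right") - 1
    same_segment = seg[:, None] == seg[None, :]
    causal = pos[None, :] <= pos[:, None]  # column (key) index <= row (query) index
    return same_segment & causal


lengths = [3, 2, 4]
cu = cu_seqlens_from_lengths(lengths)
print(cu.tolist())          # [0, 3, 5, 9]
print(int(np.max(lengths)))  # max_seqlen: 4
print(np.array_equal(dense_block_causal_mask(lengths), mask_from_cu_seqlens(cu)))  # True
```

Only the offsets and the maximum length cross into the kernel, which is why the varlen path avoids materializing the `seqlen × seqlen` mask that SDPA needs.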
Provided by: malaysia-ai



