
malaysia-ai/Flash-Attention3-wheel

Hugging Face · Updated 2025-11-19 · Indexed 2026-01-03
Download link:
https://hf-mirror.com/datasets/malaysia-ai/Flash-Attention3-wheel
Description:
# Flash-Attention3-wheel

Flash Attention 3 wheels built from commit [0e60e39473e8df549a20fb5353760f7a65b30e2d](https://github.com/Dao-AILab/flash-attention/commit/0e60e39473e8df549a20fb5353760f7a65b30e2d).

Also check out https://windreamer.github.io/flash-attention3-wheels/, which is even better!

## Build using H100

Built for the following PyTorch / CUDA combinations (minimum Python 3.9):

- PyTorch 2.6.0, CUDA 12.6
- PyTorch 2.7.0, CUDA 12.6
- PyTorch 2.7.0, CUDA 12.8
- PyTorch 2.7.1, CUDA 12.6
- PyTorch 2.7.1, CUDA 12.8

## Build using GH200 (ARM64)

Built for the following PyTorch / CUDA combinations (minimum Python 3.9):

- PyTorch 2.7.0, CUDA 12.8
- PyTorch 2.7.1, CUDA 12.8

## Installation

```bash
# Make sure to rename the downloaded file, or pip will reject it:
# remove the `-2.7.1-12.8` (PyTorch version, CUDA version) suffix.
# Check all compatible versions and architectures at
# https://huggingface.co/datasets/malaysia-ai/Flash-Attention3-wheel/tree/main
# The example below is for PyTorch 2.7.1, CUDA 12.8, x86_64.
wget https://huggingface.co/datasets/mesolitica/Flash-Attention3-whl/resolve/main/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64-2.7.1-12.8.whl \
    -O flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
pip3 install flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
```

## Unit test

Compares Flash Attention 3 against PyTorch SDPA (with an explicit block-diagonal causal mask) and against Flash Attention 2 on a packed batch of variable-length sequences.

```python
import random

import flash_attn  # Flash Attention 2, used as a reference
import flash_attn_interface  # Flash Attention 3, provided by this wheel
import numpy as np
import torch


def generate_list_sum_n(n, length=5, min_val=5):
    """Randomly split n into `length` integers, each at least `min_val`."""
    numbers = [min_val] * length
    remaining = n - min_val * length
    for _ in range(remaining):
        numbers[random.randint(0, length - 1)] += 1
    random.shuffle(numbers)
    return numbers


def block_diagonal_concat_inverted(*masks, dtype=torch.bfloat16):
    """Place 0/1 masks on the diagonal and convert to an additive mask:
    0 where attention is allowed, the dtype's minimum value elsewhere."""
    total_size = sum(mask.size(0) for mask in masks)
    combined_mask = torch.zeros(total_size, total_size, dtype=dtype)

    current_pos = 0
    for mask in masks:
        size = mask.size(0)
        combined_mask[current_pos:current_pos + size, current_pos:current_pos + size] = mask
        current_pos += size

    min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    inverted_mask = torch.where(combined_mask == 1, torch.tensor(0, dtype=dtype), min_value)
    return inverted_mask.unsqueeze(0)


sequence_length = 4096
# Split the packed sequence into 20 documents, each at least 10 tokens long.
query_lens = np.array(generate_list_sum_n(sequence_length, length=20, min_val=10), dtype=np.int64)

# One causal (lower-triangular) mask per document, combined block-diagonally.
masks = [torch.tril(torch.ones(m, m)) for m in query_lens]
attention_mask = block_diagonal_concat_inverted(*masks).cuda()

# Layout: (batch, seq_len, num_heads, head_dim)
q = torch.randn(1, sequence_length, 128, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(1, sequence_length, 128, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(1, sequence_length, 128, 128, dtype=torch.bfloat16).cuda()

# Reference: SDPA expects (batch, num_heads, seq_len, head_dim).
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    query=q.transpose(1, 2),
    key=k.transpose(1, 2),
    value=v.transpose(1, 2),
    attn_mask=attention_mask[None],
)

# Cumulative sequence lengths: document i spans tokens cumsum[i]:cumsum[i + 1].
cumsum = [0] + np.cumsum(query_lens).tolist()
cu_seq_lens_q = torch.tensor(cumsum, dtype=torch.int32).cuda()
max_seqlen_q = int(np.max(query_lens))

out_flash2 = flash_attn.flash_attn_varlen_func(
    q=q[0],
    k=k[0],
    v=v[0],
    cu_seqlens_q=cu_seq_lens_q,
    cu_seqlens_k=cu_seq_lens_q,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_q,
    causal=True,
)

out_flash3 = flash_attn_interface.flash_attn_varlen_func(
    q=q[0],
    k=k[0],
    v=v[0],
    cu_seqlens_q=cu_seq_lens_q,
    cu_seqlens_k=cu_seq_lens_q,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_q,
    causal=True,
)

# bfloat16 outputs, so compare with a loose absolute tolerance.
assert torch.allclose(out_flash3, out_sdpa[0].transpose(0, 1), atol=0.125, rtol=0)
assert torch.allclose(out_flash3, out_flash2, atol=0.125, rtol=0)
```
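The rename step in the installation instructions can be automated. pip rejects the downloaded name because a wheel filename must have the form `{distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl` (PEP 427), and the extra `-2.7.1-12.8` suffix adds two more dash-separated fields. Below is a minimal sketch, assuming every file in the repository follows the pattern shown above (a standard wheel name plus a `-<PyTorch version>-<CUDA version>` suffix before `.whl`). The helper names `split_wheel_name`, `looks_like_pep427`, and `rename_for_pip` are hypothetical and not part of the repository.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Assumed naming scheme (hypothetical, inferred from the example filename):
# <standard wheel name>-<torch version>-<CUDA version>.whl, e.g.
# flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64-2.7.1-12.8.whl
_SUFFIX = re.compile(r"-(?P<torch>\d+\.\d+\.\d+)-(?P<cuda>\d+\.\d+)\.whl$")


def split_wheel_name(filename: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Return (pip-compatible name, torch version, CUDA version).

    Files without the torch/CUDA suffix are returned unchanged,
    with None for both versions.
    """
    match = _SUFFIX.search(filename)
    if match is None:
        return filename, None, None
    clean = filename[: match.start()] + ".whl"
    return clean, match.group("torch"), match.group("cuda")


def looks_like_pep427(filename: str) -> bool:
    """Rough heuristic: 5 dash-separated fields, or 6 with a build tag."""
    if not filename.endswith(".whl"):
        return False
    parts = filename[: -len(".whl")].split("-")
    return len(parts) in (5, 6)


def rename_for_pip(path: Path) -> Path:
    """Rename a downloaded wheel in place so pip accepts its filename."""
    clean, _, _ = split_wheel_name(path.name)
    target = path.with_name(clean)
    if target != path:
        path.rename(target)
    return target


if __name__ == "__main__":
    name = "flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64-2.7.1-12.8.whl"
    print(split_wheel_name(name))
    # ('flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl', '2.7.1', '12.8')
```

Before installing, you may want to compare the extracted versions with your environment's `torch.__version__` and `torch.version.cuda`. The dash-count check is only a quick sanity test, not a full PEP 427 validator.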
Provided by:
malaysia-ai