UCF-101
Dataset Overview
Dataset Download
- Dataset name: UCF-101
- Download link: UCF-101.rar
- Download commands (`[DATA_ROOT]` is the directory that will hold the dataset):
  ```bash
  cd [DATA_ROOT]
  wget https://www.crcv.ucf.edu/data/UCF101/UCF101.rar --no-check-certificate
  unrar x UCF101.rar
  ```
Dataset Preprocessing
- Preprocessing script: split_ucf.py
- Preprocessing commands:
  ```bash
  cd CoordTok/data
  python split_ucf.py --data_root [DATA_ROOT] --data_name UCF-101
  ```
Dataset Structure
- Dataset path: [DATA_ROOT]/UCF-101_train
- Example structure:
  ```
  [DATA_ROOT]/UCF-101_train
  |-- class1
      |-- video1.avi
      |-- video2.avi
      |-- ...
  |-- class2
      |-- video1.avi
      |-- video2.avi
      |-- ...
  ...
  ```
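The layout above can be indexed with a short directory walk. This is a minimal sketch, not part of CoordTok: `index_ucf_train` is a hypothetical helper that assumes one sub-directory per action class containing `.avi` files, and it numbers classes in sorted folder-name order, which may differ from the ordering used by CoordTok's own data loader.

```python
from pathlib import Path


def index_ucf_train(train_root):
    """List (video_path, class_index) pairs under [DATA_ROOT]/UCF-101_train.

    Hypothetical helper: assumes one sub-directory per class, each holding
    .avi files, and numbers classes in sorted folder-name order.
    """
    root = Path(train_root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for class_index, name in enumerate(class_names):
        # Sort so the listing is deterministic across file systems;
        # non-.avi files in a class folder are ignored.
        for video in sorted((root / name).glob("*.avi")):
            samples.append((video, class_index))
    return samples
```

For example, `len(index_ucf_train("/path/to/data/UCF-101_train"))` gives the number of training videos produced by the preprocessing step.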
Training Script
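- CoordTok training command (replace `N` with the number of GPUs on the node and `M` with the value of `--accum_iter`, presumably the number of gradient-accumulation iterations):
  ```bash
  torchrun --nnodes=1 --nproc_per_node=N train_coordtok.py \
      --data_root [DATA_ROOT] \
      --num_views 256 \
      --num_iters 1000001 \
      --accum_iter M \
      --enc_embed_dim 1024 \
      --enc_num_layers 24 \
      --enc_num_heads 16 \
      --enc_patch_num_layers 8 \
      --dec_embed_dim 1024 \
      --dec_num_layers 24 \
      --dec_num_heads 16 \
      --point_per_vid 1024 \
      --allow_tf32 \
      --lpips_loss_scale 0.0
  ```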
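To vary `N` and `M` without retyping every flag, the invocation can be assembled programmatically. This is a sketch under stated assumptions: `build_train_command` and `COORDTOK_FLAGS` are hypothetical names, the flag values are copied from the command above, and moving `--accum_iter` earlier in the list assumes `train_coordtok.py` parses options order-independently (as argparse does).

```python
import shlex

# Flag values copied from the training command above (hypothetical container).
COORDTOK_FLAGS = {
    "num_views": 256,
    "num_iters": 1000001,
    "enc_embed_dim": 1024,
    "enc_num_layers": 24,
    "enc_num_heads": 16,
    "enc_patch_num_layers": 8,
    "dec_embed_dim": 1024,
    "dec_num_layers": 24,
    "dec_num_heads": 16,
    "point_per_vid": 1024,
    "lpips_loss_scale": 0.0,
}


def build_train_command(data_root, num_gpus, accum_iter, flags=COORDTOK_FLAGS):
    """Return the torchrun invocation as an argument list (hypothetical helper)."""
    cmd = [
        "torchrun", "--nnodes=1", f"--nproc_per_node={num_gpus}",
        "train_coordtok.py",
        "--data_root", str(data_root),
        "--accum_iter", str(accum_iter),
    ]
    for key, value in flags.items():
        cmd += [f"--{key}", str(value)]
    cmd.append("--allow_tf32")  # boolean switch, takes no value
    return cmd
```

`print(shlex.join(build_train_command("/path/to/data", 8, 1)))` prints a copy-pasteable command line; passing the list to `subprocess.run(..., check=True)` launches it directly without shell quoting issues.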
Evaluation Script
- CoordTok video reconstruction:
  ```python
  import torch
  from models.coordtok.coordtok_model import CoordTok
  from tools.utils_coordtok import decode_video

  model = CoordTok(video_shape=(128, 128, 128),  # Shape (T, H, W)
                   enc_embed_dim=1024,
                   enc_num_layers=24,
                   enc_num_heads=16,
                   enc_patch_size_xy=16,
                   enc_patch_size_t=8,
                   enc_patch_type="transformer",
                   enc_patch_num_layers=8,
                   latent_resolution_xy=16,
                   latent_resolution_t=8,
                   latent_n_features=8,
                   latent_patch_size_xy=8,
                   latent_patch_size_t=16,
                   dec_embed_dim=1024,
                   dec_num_layers=24,
                   dec_num_heads=16,
                   dec_patch_size_xy=8,
                   dec_patch_size_t=1,
                   lpips_loss_scale=0).cuda()

  # Dummy input; replace with a real clip.
  x = torch.zeros(1, 128, 128, 128, 3).cuda()  # Shape (BS, T, H, W, 3) / Range [-1, 1]
  n_frames = torch.tensor([[128]], dtype=torch.int64).cuda()  # Shape (BS, 1)

  z_xy, z_yt, z_xt = model.encode(x, n_frames)  # triplane representation

  x_recon = decode_video(model,
                         params=[z_xy, z_yt, z_xt],
                         img_size=128,
                         num_frames=128,
                         patch_pred=(1, 8, 8),  # Shape (dec_patch_size_t, dec_patch_size_xy, dec_patch_size_xy)
                         max_num_frames=128,
                         Nslice=1)  # Range [-1, 1]
  x_recon = (x_recon + 1) / 2
  x_recon = torch.clamp(x_recon, 0, 1)  # Range [0, 1]
  ```
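The reconstruction example feeds a zero tensor. The following minimal sketch shows one way to prepare a real clip and score the result; `to_model_input` and `video_psnr` are hypothetical helpers, not CoordTok APIs. It assumes a decoded uint8 clip of shape (T, H, W, 3) matching the configured `video_shape`, and it computes PSNR over each whole video. Some evaluations average per-frame PSNR instead, which gives different numbers.

```python
import torch


def to_model_input(frames: torch.Tensor):
    """Convert a uint8 clip (T, H, W, 3) into inputs shaped like the example above.

    Returns x of shape (1, T, H, W, 3) scaled to [-1, 1] and
    n_frames of shape (1, 1) holding T.
    """
    x = frames.float().div(255.0).mul(2.0).sub(1.0).unsqueeze(0)
    n_frames = torch.tensor([[frames.shape[0]]], dtype=torch.int64)
    return x, n_frames


def video_psnr(x: torch.Tensor, x_recon: torch.Tensor, eps: float = 1e-10):
    """Per-video PSNR in dB; both inputs are (BS, T, H, W, 3) in [0, 1]."""
    # Mean squared error over every pixel of each video in the batch.
    mse = (x - x_recon).pow(2).flatten(start_dim=1).mean(dim=1)
    # Peak signal is 1.0 because the inputs are in [0, 1]; eps avoids log10(0).
    return 10.0 * torch.log10(1.0 / (mse + eps))
```

With the model above, build `x, n_frames` via `to_model_input`, move both to the GPU with `.cuda()`, run `model.encode` and `decode_video` as shown, then call `video_psnr((x + 1) / 2, x_recon)`. The rescaling is needed because `x` is in [-1, 1] while `x_recon` has already been mapped to [0, 1].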




