xingjianleng/U-ViT-coco
收藏 | Hugging Face | 2025-11-15 更新 | 2025-12-20 收录
下载链接:
https://hf-mirror.com/datasets/xingjianleng/U-ViT-coco
下载链接
链接失效反馈 | 官方服务:
资源简介:
---
license: mit
---
# U-ViT-coco
## Download the data
Download the `datasets`, `fid_stats`, and `stable-diffusion` folders and place them in an `assets` folder (the example code below reads FID statistics from `assets/fid_stats/`).
## Pre-processed MSCOCO dataset using modified code from [U-ViT](https://github.com/baofff/U-ViT)
- RGB images are center-cropped to 256×256 resolution before saving
- Latents pre-extracted from [SD-VAE](https://huggingface.co/stabilityai/sd-vae-ft-ema)
- Prompt features extracted using [CLIP-L/14](https://huggingface.co/openai/clip-vit-large-patch14)
- **Intended for diffusion model training (with REPA / REPA-E support)**
## Example dataset code
```python
import os
import random
from datasets import load_from_disk
import numpy as np
import torch
from torch.utils.data import Dataset
class DatasetFactory:
    """Base class bundling the train/test splits of a dataset.

    Subclasses populate ``self.train`` / ``self.test`` and override the
    shape / label hooks (``data_shape``, ``sample_label``, ``label_prob``)
    as needed.
    """

    def __init__(self):
        self.train = None
        self.test = None

    def get_split(self, split, labeled=False):
        """Return the dataset object stored for ``split``.

        Args:
            split: ``"train"`` or ``"test"``.
            labeled: when ``has_label`` is True the split is returned unchanged
                regardless of this flag; when ``has_label`` is False, passing
                ``labeled=True`` fails with an ``AssertionError``.

        Returns:
            ``self.train`` or ``self.test`` (may be None if the subclass did
            not populate that split).

        Raises:
            ValueError: if ``split`` is not ``"train"`` or ``"test"``.
        """
        if split == "train":
            dataset = self.train
        elif split == "test":
            dataset = self.test
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train' or 'test'")
        if not self.has_label:
            # Without labels, a labeled view cannot be produced.
            assert not labeled
        # NOTE: ``labeled`` does not select a different view; the split object
        # is returned as-is in both cases.
        return dataset

    def unpreprocess(self, v):
        """Map values from [-1, 1] to [0, 1], clamping anything outside that range.

        ``v`` is expected to be a torch tensor (``clamp_`` is called on the
        result). The shape is left unchanged, and the caller's tensor is not
        modified because the arithmetic creates a new tensor before the
        in-place clamp.
        """
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    @property
    def has_label(self):
        """Whether the dataset items carry labels/conditioning (True by default)."""
        return True

    @property
    def data_shape(self):
        """Shape of one data item; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def data_dim(self):
        """Total number of elements in one data item (product of ``data_shape``)."""
        return int(np.prod(self.data_shape))

    @property
    def fid_stat(self):
        """Path to precomputed FID statistics, or None when not available."""
        return None

    def sample_label(self, n_samples, device):
        """Sample ``n_samples`` labels on ``device``; subclasses must implement."""
        raise NotImplementedError

    def label_prob(self, k):
        """Probability of label ``k``; subclasses must implement."""
        raise NotImplementedError
class HFMSCOCOFeatureDataset(Dataset):
    """Map-style dataset over one pre-processed MSCOCO split saved with HF ``datasets``.

    The split at ``root`` is opened with ``load_from_disk``. Each row is
    expected to have ``image``, ``moments`` and ``contexts`` columns.
    """

    def __init__(self, root):
        self.root = root
        self.datasets = load_from_disk(root)

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, index):
        """Return ``(x, z, c)`` for row ``index``.

        x: image as a channels-first tensor, keeping the image array's dtype.
        z: float32 tensor built from the row's ``moments`` (presumably
           [8, 32, 32] SD-VAE moments — verify against the stored data).
        c: float32 tensor of ONE entry of ``contexts``, picked uniformly at
           random (presumably one CLIP caption embedding — verify).
        """
        row = self.datasets[index]

        # Images become HWC arrays (HW for single-channel). Reshaping to
        # (H, W, -1) gives every image an explicit channel axis before the
        # HWC -> CHW transpose.
        pixels = np.array(row["image"])
        pixels = pixels.reshape(*pixels.shape[:2], -1).transpose(2, 0, 1)

        moments = np.array(row["moments"])

        # Each image has several captions; draw one per access.
        captions = row["contexts"]
        caption = np.array(captions[random.randrange(len(captions))])

        return (
            torch.from_numpy(pixels),
            torch.from_numpy(moments).float(),
            torch.from_numpy(caption).float(),
        )
class CFGDataset(Dataset):
    """Wrapper used for classifier-free guidance training.

    Wraps a dataset whose items unpack to ``(x, z, y)``. On every access, with
    probability ``p_uncond``, the conditioning ``y`` is replaced by
    ``empty_token``; otherwise the item is returned unchanged.
    """

    def __init__(self, dataset, p_uncond, empty_token):
        self.dataset = dataset
        self.p_uncond = p_uncond
        self.empty_token = empty_token

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        x, z, cond = self.dataset[item]
        # One random draw per access decides whether the condition is dropped.
        drop_condition = random.random() < self.p_uncond
        return x, z, (self.empty_token if drop_condition else cond)
class MSCOCO256Features(DatasetFactory):
    """MSCOCO-256 features: Stable Diffusion image-encoder moments plus CLIP text contexts.

    Args:
        path: directory containing the ``train`` / ``val`` splits (saved with
            HF ``datasets``) and ``empty_context.npy``.
        cfg: in train mode, wrap the train split in ``CFGDataset`` for
            classifier-free guidance. Ignored when ``mode == 'val'``.
        p_uncond: probability of swapping a caption context for the empty
            context; must not be None when ``cfg`` is used.
        mode: ``'val'`` loads the validation split into ``self.test``; any
            other value loads the training split into ``self.train``.
    """

    def __init__(self, path, cfg=True, p_uncond=0.1, mode='train'):
        super().__init__()
        print('Prepare dataset...')
        is_val = mode == 'val'
        if is_val:
            self.test = HFMSCOCOFeatureDataset(os.path.join(path, 'val'))
            # Sanity check against the expected size of the pre-processed split.
            assert len(self.test) == 40504, (
                f"expected 40504 validation samples under {path!r}, found {len(self.test)}"
            )
        else:
            self.train = HFMSCOCOFeatureDataset(os.path.join(path, 'train'))
            assert len(self.train) == 82783, (
                f"expected 82783 training samples under {path!r}, found {len(self.train)}"
            )
        # The empty-caption context is needed in both modes; it is loaded once,
        # after the split, as the original per-branch code did.
        self.empty_context = torch.from_numpy(np.load(os.path.join(path, 'empty_context.npy'))).float()
        if not is_val and cfg:  # classifier free guidance (train only)
            assert p_uncond is not None
            print(f'prepare the dataset for classifier free guidance with p_uncond={p_uncond}')
            self.train = CFGDataset(self.train, p_uncond, self.empty_context)

    @property
    def data_shape(self):
        """Per-sample data shape (4, 32, 32).

        NOTE(review): stored moments are documented as [8, 32, 32]; 4 is
        presumably the channel count of a latent sampled from them — confirm.
        """
        return 4, 32, 32

    @property
    def fid_stat(self):
        """Relative path to the precomputed MSCOCO-256 validation FID statistics."""
        return 'assets/fid_stats/fid_stats_mscoco256_val.npz'
```
提供机构:
xingjianleng



