five

m-a-p/SMuPT_v0_8192_770M

收藏
Hugging Face2024-01-13 更新2024-03-04 收录
下载链接:
https://hf-mirror.com/datasets/m-a-p/SMuPT_v0_8192_770M
下载链接
链接失效反馈
官方服务:
资源简介:
SMuPT是一个用于符号音乐生成的预训练模型系列。它基于大规模的符号音乐数据集进行训练,该数据集包含数百万首单音和多音轨的音乐作品,涵盖不同的流派和风格。模型采用LLama2架构,可用于旋律生成、伴奏生成和多轨音乐生成等下游任务。

SMuPT是一个用于符号音乐生成的预训练模型系列。它基于大规模的符号音乐数据集进行训练,该数据集包含数百万首单音和多音轨的音乐作品,涵盖不同的流派和风格。模型采用LLama2架构,可用于旋律生成、伴奏生成和多轨音乐生成等下游任务。
提供机构:
m-a-p
原始信息汇总

SMuPT: Symbolic Music Generative Pre-trained Transformer

SMuPT 是一系列用于符号音乐生成的预训练模型。它在大规模符号音乐数据集上进行训练,包括数百万种不同风格和流派的单声部和多声部作品。这些模型采用 LLama2 架构进行训练,可用于下游音乐生成任务,如旋律生成、伴奏生成和多轨音乐生成。

模型发布

  • 2024年9月1日:发布了一系列预训练的 SMuPT 模型,参数范围从 110M 到 1.3B。

模型架构

以下是 SMuPT-v0 模型的架构细节:

名称 参数数量 训练数据(音乐作品) 序列长度 隐藏层大小 层数 头数
SMuPT-v0-8192-110M 110M 7M x 5.8 轮 8192 768 12 12
SMuPT-v0-8192-345M 345M 7M x 4 轮 8192 1024 24 16
SMuPT-v0-8192-770M 770M 7M x 3 轮 8192 1280 36 20
SMuPT-v0-8192-1.3B 1.3B 7M x 2.2 轮 8192 1536 48 24

模型使用

使用预训练的 SMuPT 模型有多种方式,目前基于 Megatron-LM 的使用方法如下:

环境设置

  1. 拉取 Megatron-LM 代码库: shell mkdir -p /path/to/workspace && cd /path/to/workspace git clone https://github.com/NVIDIA/Megatron-LM.git

  2. 下载预训练的 SMuPT 模型检查点和词汇文件: shell mkdir -p /models/SMuPT_v0_8192_1.3B && cd /models/SMuPT_v0_8192_1.3B wget -O model_optim_rng.pt https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/model_optim_rng.pt?download=true wget -O newline.vocab https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/newline.vocab?download=true wget -O newline.txt https://huggingface.co/m-a-p/SMuPT_v0_8192_1.3B/resolve/main/newline.txt?download=true

  3. 推荐使用最新版本的 NGCs PyTorch 容器 进行 SMuPT 推理: shell docker run --gpus all -it --name megatron --shm-size=16g -v $PWD:/workspace -p 5000:5000 nvcr.io/nvidia/pytorch:23.11-py3 /bin/bash

启动推理服务器

进入容器后,可以启动 REST 服务器进行推理:

shell #!/bin/bash export CUDA_DEVICE_MAX_CONNECTIONS=1

DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6000"

CHECKPOINT=/path/to/model/checkpoint/folder VOCAB_FILE=/path/to/vocab/file MERGE_FILE=/path/to/merge/file

MODEL_SIZE="1.3B" if [[ ${MODEL_SIZE} == "110M" ]]; then HIDDEN_SIZE=768; NUM_HEAD=12; NUM_QUERY_GROUP=12; NUM_LAYERS=12; FFN_HIDDEN_SIZE=3072; NORM_EPS=1e-5; elif [[ ${MODEL_SIZE} == "345M" ]]; then HIDDEN_SIZE=1024; NUM_HEAD=16; NUM_QUERY_GROUP=16; NUM_LAYERS=24; FFN_HIDDEN_SIZE=4096; NORM_EPS=1e-5; elif [[ ${MODEL_SIZE} == "770M" ]]; then HIDDEN_SIZE=1280; NUM_HEAD=20; NUM_QUERY_GROUP=20; NUM_LAYERS=36; FFN_HIDDEN_SIZE=5120; NORM_EPS=1e-5; elif [[ ${MODEL_SIZE} == "1.3B" ]]; then HIDDEN_SIZE=1536; NUM_HEAD=24; NUM_QUERY_GROUP=24; NUM_LAYERS=48; FFN_HIDDEN_SIZE=6144; NORM_EPS=1e-5; else echo "invalid MODEL_SIZE: ${MODEL_SIZE}"; exit 1 fi MAX_SEQ_LEN=8192 MAX_POSITION_EMBEDDINGS=8192

pip install flask-restful

torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--num-layers ${NUM_LAYERS}
--hidden-size ${HIDDEN_SIZE}
--ffn-hidden-size ${FFN_HIDDEN_SIZE} --load ${CHECKPOINT}
--group-query-attention --num-query-groups ${NUM_QUERY_GROUP} --position-embedding-type rope --num-attention-heads ${NUM_HEAD}
--max-position-embeddings ${MAX_POSITION_EMBEDDINGS}
--tokenizer-type GPT2BPETokenizer
--normalization RMSNorm --norm-epsilon ${NORM_EPS} --make-vocab-size-divisible-by 1 --swiglu --use-flash-attn --bf16
--micro-batch-size 1
--disable-bias-linear --no-bias-gelu-fusion --untie-embeddings-and-output-weights --seq-length ${MAX_SEQ_LEN}
--vocab-file $VOCAB_FILE
--merge-file $MERGE_FILE
--attention-dropout 0.0 --hidden-dropout 0.0 --weight-decay 1e-1 --clip-grad 1.0 --adam-beta1 0.9 --adam-beta2 0.95 --adam-eps 1e-8 --seed 42

查询服务器

使用 CURL 直接查询服务器,注意换行符 在词汇表中用 <n> 表示,因此在提示和生成的令牌中需要替换换行符:

shell curl http://localhost:6000/api -X PUT -H Content-Type: application/json; charset=UTF-8 -d {"prompts":["X:1<n>L:1/8<n>M:4/4<n>K:G<n>GA"], "tokens_to_generate":4096}

处理后的输出: shell X:1 L:1/8 M:4/4 K:G GA | B2 B2 B2 (cd) | B2 A2 z2 AB | c2 c2 c2 (de) | d4 z2 B2 | d2 d2 d2 e>d | c2 B2 z2 dB | A2 A2 A2 B2 | G4 z2 GA | B2 B2 B2 cd | B2 A2 z2 AB | c2 c2 e2 dc | d4 z2 GA | B2 B2 B2 cd | B2 A2 z2 dB | A3 G A2 B2 | G4 z2 |]

将生成的令牌编码为音频后,您将听到以下音乐。

5,000+
优质数据集
54 个
任务类型
进入经典数据集
二维码
社区交流群

面向社区/商业的数据集话题

二维码
科研交流群

面向高校/科研机构的开源数据集话题

数据驱动未来

携手共赢发展

商业合作