humair025/tts-token-shimmer-std
收藏Hugging Face2025-11-05 更新2025-11-15 收录
下载链接:
https://hf-mirror.com/datasets/humair025/tts-token-shimmer-std
下载链接
链接失效反馈官方服务:
资源简介:
```python
"""
SNAC Decoder - decoder.py
Parses annotated_data or token list and decodes to WAV audio file
Usage:
# From annotated_data string
python decoder.py --annotated "<|START_OF_USER|>...<|end|>" --output audio.wav
# From token list
python decoder.py --tokens "[123, 456, 789, ...]" --output audio.wav
# From JSON file
python decoder.py --json sample.json --output audio.wav
"""
import re
import json
import torch
import torchaudio
import argparse
from snac import SNAC
from typing import List, Union
class SNACDecoder:
    """Decode a flattened SNAC token stream into a waveform and save it as WAV.

    Tokens can come from an ``annotated_data`` string (``<|SPEECH_n|>``
    markers between ``<|start|>`` and ``<|end|>``), from a textual token
    list, or from a JSON file holding either of those.  The flat stream is
    split back into the three SNAC code layers before decoding.
    """

    # One flattened frame is [L1, L2, L3, L3, L2, L3, L3].
    TOKENS_PER_FRAME = 7

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Load the pretrained SNAC model onto ``device``.

        Args:
            device: Target device string.  Defaults to ``"cuda"`` when CUDA
                is available, otherwise ``"cpu"``.
        """
        print(f"Loading SNAC model on {device}...")
        self.device = device
        self.model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()
        # Prefix check so indexed devices such as "cuda:0" also get half precision.
        if str(device).startswith("cuda"):
            self.model = self.model.half()
        self.sample_rate = 24000
        print("✓ Model loaded successfully")

    def parse_annotated_data(self, annotated_data: str) -> List[int]:
        """Extract speech token ids from an ``annotated_data`` string.

        Expected format::

            <|START_OF_USER|><|USER|><|TEXT|>phonemes<|TEXT|><|VOICE|>voice<|VOICE|>
            <|END_OF_USER|><|GENERATE_SPEECH|><|start|><|SPEECH_123|><|SPEECH_456|>...<|end|>

        Args:
            annotated_data: Full annotated string.

        Returns:
            The integers found in ``<|SPEECH_n|>`` markers, in order.

        Raises:
            ValueError: If the ``<|start|>...<|end|>`` section is missing or
                contains no speech tokens.
        """
        print("Parsing annotated_data string...")
        # DOTALL lets the speech section span line breaks.
        match = re.search(r'<\|start\|>(.*?)<\|end\|>', annotated_data, re.DOTALL)
        if not match:
            raise ValueError("Could not find <|start|>...<|end|> in annotated_data")
        speech_section = match.group(1)
        tokens = [int(m) for m in re.findall(r'<\|SPEECH_(\d+)\|>', speech_section)]
        if not tokens:
            raise ValueError("No tokens found in annotated_data")
        print(f"✓ Extracted {len(tokens)} tokens")
        return tokens

    def parse_token_list(self, token_str: str) -> List[int]:
        """Parse a textual token list into integers.

        Accepts ``"[123, 456, 789]"``, ``"123, 456, 789"``, ``"123 456 789"``
        and any mix of commas and whitespace as separators.

        Args:
            token_str: String representation of the token list.

        Returns:
            List of token integers.

        Raises:
            ValueError: If a piece is not an integer or no tokens are found.
        """
        print("Parsing token list...")
        cleaned = token_str.strip().strip('[]')
        # Split on commas and whitespace together so mixed separators work.
        pieces = [piece for piece in re.split(r'[,\s]+', cleaned) if piece]
        tokens = [int(piece) for piece in pieces]
        if not tokens:
            raise ValueError("No valid tokens found in input")
        print(f"✓ Parsed {len(tokens)} tokens")
        return tokens

    def redistribute_codes(self, flattened_tokens: List[int]) -> List[torch.Tensor]:
        """Split a flattened token stream back into the 3 SNAC layers.

        Each frame of 7 tokens is laid out as ``[L1, L2, L3, L3, L2, L3, L3]``.
        Only complete frames are used, so the layers always hold 1, 2 and 4
        tokens per frame; trailing tokens of a partial frame are dropped
        with a warning.

        Args:
            flattened_tokens: Flat list of tokens.

        Returns:
            ``[layer_1, layer_2, layer_3]``, each a ``torch.long`` tensor of
            shape ``[1, n]`` on ``self.device``.

        Raises:
            ValueError: If fewer than one complete frame of tokens is given.
        """
        print("Redistributing tokens to 3 SNAC layers...")
        frame = self.TOKENS_PER_FRAME
        total = len(flattened_tokens)
        usable = (total // frame) * frame
        if usable == 0:
            raise ValueError(
                f"Need at least {frame} tokens (one complete frame) to decode, got {total}"
            )
        if usable != total:
            print(f"⚠ Dropping {total - usable} trailing tokens that do not form a complete frame")

        layer_1 = []
        layer_2 = []
        layer_3 = []
        for base in range(0, usable, frame):
            chunk = flattened_tokens[base:base + frame]
            layer_1.append(chunk[0])
            layer_2.extend((chunk[1], chunk[4]))
            layer_3.extend((chunk[2], chunk[3], chunk[5], chunk[6]))

        # Add a leading batch dimension of 1 to each layer.
        codes = [
            torch.tensor(layer, dtype=torch.long).unsqueeze(0).to(self.device)
            for layer in (layer_1, layer_2, layer_3)
        ]
        print(f"✓ Redistributed to:")
        print(f" Layer 1: {codes[0].shape} ({codes[0].numel()} tokens)")
        print(f" Layer 2: {codes[1].shape} ({codes[1].numel()} tokens)")
        print(f" Layer 3: {codes[2].shape} ({codes[2].numel()} tokens)")
        return codes

    def decode_to_audio(self, tokens: List[int]) -> torch.Tensor:
        """Decode a flattened token list to an audio waveform.

        Args:
            tokens: Flattened token list.

        Returns:
            Float waveform tensor on the CPU, shaped ``[1, samples]``.
        """
        print("\nDecoding tokens to audio...")
        codes = self.redistribute_codes(tokens)
        with torch.no_grad():
            audio = self.model.decode(codes)
        # Collapse singleton dims, then restore a leading channel dimension.
        audio = audio.detach().cpu().float().squeeze()
        if audio.dim() == 0:
            audio = audio.unsqueeze(0)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        duration = audio.shape[-1] / self.sample_rate
        print(f"✓ Decoded audio:")
        print(f" Shape: {audio.shape}")
        print(f" Duration: {duration:.2f} seconds")
        print(f" Sample rate: {self.sample_rate} Hz")
        return audio

    def save_wav(self, audio: torch.Tensor, output_path: str):
        """Write ``audio`` to ``output_path`` at ``self.sample_rate`` via torchaudio."""
        print(f"\nSaving audio to: {output_path}")
        torchaudio.save(output_path, audio, self.sample_rate)
        print(f"✓ Saved successfully")

    @staticmethod
    def _print_banner(title: str) -> None:
        """Print the section header shown at the start of each pipeline."""
        print("="*70)
        print(title)
        print("="*70)

    def _decode_and_save(self, tokens: List[int], output_path: str) -> None:
        """Decode ``tokens``, write the WAV file and print the completion footer."""
        audio = self.decode_to_audio(tokens)
        self.save_wav(audio, output_path)
        print("\n" + "="*70)
        print("✅ DECODING COMPLETE")
        print("="*70)

    def decode_from_annotated(self, annotated_data: str, output_path: str):
        """Run the full pipeline: parse ``annotated_data``, decode, save WAV.

        Args:
            annotated_data: Full annotated data string.
            output_path: Output WAV file path.
        """
        self._print_banner("DECODING FROM ANNOTATED_DATA")
        tokens = self.parse_annotated_data(annotated_data)
        self._decode_and_save(tokens, output_path)

    def decode_from_tokens(self, tokens: Union[List[int], str], output_path: str):
        """Run the full pipeline: parse tokens if needed, decode, save WAV.

        Args:
            tokens: Token list, or its string representation.
            output_path: Output WAV file path.
        """
        self._print_banner("DECODING FROM TOKEN LIST")
        if isinstance(tokens, str):
            tokens = self.parse_token_list(tokens)
        self._decode_and_save(tokens, output_path)

    def decode_from_json(self, json_path: str, output_path: str, use_annotated: bool = True):
        """Decode a dataset entry stored in a JSON file.

        Args:
            json_path: Path to the JSON file.
            output_path: Output WAV file path.
            use_annotated: If True and the entry has ``annotated_data``, parse
                that; otherwise use the ``tokens`` array directly.

        Raises:
            ValueError: If no usable ``annotated_data`` or ``tokens`` key is
                present in the JSON data.
        """
        self._print_banner("DECODING FROM JSON FILE")
        print(f"Reading: {json_path}")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"✓ Loaded JSON data")

        # Show whichever optional metadata fields the entry carries.
        if 'transcript' in data:
            print(f"\nTranscript: {data['transcript'][:80]}...")
        if 'phonemes' in data:
            print(f"Phonemes: {data['phonemes'][:80]}...")
        if 'voice' in data:
            print(f"Voice: {data['voice']}")
        if 'total_tokens' in data:
            print(f"Total tokens: {data['total_tokens']}")
        if 'duration_sec' in data:
            print(f"Expected duration: {data['duration_sec']}s")

        if use_annotated and 'annotated_data' in data:
            print("\nUsing annotated_data for decoding...")
            tokens = self.parse_annotated_data(data['annotated_data'])
        elif 'tokens' in data:
            print("\nUsing tokens array for decoding...")
            tokens = data['tokens']
        else:
            raise ValueError("JSON must contain either 'annotated_data' or 'tokens'")
        self._decode_and_save(tokens, output_path)
def main():
    """Command-line entry point: parse arguments and run the chosen decode.

    Exactly one of ``--annotated``, ``--tokens`` or ``--json`` is required,
    together with ``--output``.

    Returns:
        0 on success, 1 if decoding raised an exception (the error and its
        traceback are printed).
    """
    parser = argparse.ArgumentParser(
        description='SNAC Decoder - Convert tokens or annotated_data to WAV audio',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# From annotated_data string
python decoder.py --annotated "<|START_OF_USER|>...<|end|>" --output audio.wav
# From token list
python decoder.py --tokens "[123, 456, 789, 234, 567]" --output audio.wav
# From JSON file (using annotated_data)
python decoder.py --json sample.json --output audio.wav
# From JSON file (using tokens directly)
python decoder.py --json sample.json --output audio.wav --use-tokens
"""
    )
    # Input options (mutually exclusive)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--annotated', type=str,
                             help='Annotated data string')
    input_group.add_argument('--tokens', type=str,
                             help='Token list as string: "[123, 456, ...]"')
    input_group.add_argument('--json', type=str,
                             help='Path to JSON file')
    # Output
    parser.add_argument('--output', '-o', type=str, required=True,
                        help='Output WAV file path')
    # Options
    parser.add_argument('--use-tokens', action='store_true',
                        help='When using --json, decode from tokens array instead of annotated_data')
    parser.add_argument('--device', type=str, default=None,
                        help='Device: cuda or cpu (default: auto-detect)')
    args = parser.parse_args()

    # Initialize decoder
    device = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    decoder = SNACDecoder(device=device)

    # Compare against None rather than truthiness: an empty string argument
    # must reach the parser and be reported, not silently match no branch.
    try:
        if args.annotated is not None:
            decoder.decode_from_annotated(args.annotated, args.output)
        elif args.tokens is not None:
            decoder.decode_from_tokens(args.tokens, args.output)
        elif args.json is not None:
            decoder.decode_from_json(args.json, args.output,
                                     use_annotated=not args.use_tokens)
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # site module for interactive use and is not guaranteed to be available.
    raise SystemExit(main())
```
提供机构:
humair025



