five

humair025/tts-token-shimmer-std

收藏
Hugging Face · 2025-11-05 更新 · 2025-11-15 收录
下载链接:
https://hf-mirror.com/datasets/humair025/tts-token-shimmer-std
下载链接
链接失效反馈
官方服务:
资源简介:
```python """ SNAC Decoder - decoder.py Parses annotated_data or token list and decodes to WAV audio file Usage: # From annotated_data string python decoder.py --annotated "<|START_OF_USER|>...<|end|>" --output audio.wav # From token list python decoder.py --tokens "[123, 456, 789, ...]" --output audio.wav # From JSON file python decoder.py --json sample.json --output audio.wav """ import re import json import torch import torchaudio import argparse from snac import SNAC from typing import List, Union class SNACDecoder: def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"): """Initialize SNAC decoder model.""" print(f"Loading SNAC model on {device}...") self.device = device self.model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval() if device == "cuda": self.model = self.model.half() self.sample_rate = 24000 print("✓ Model loaded successfully") def parse_annotated_data(self, annotated_data: str) -> List[int]: """ Parse annotated_data string to extract tokens. Format: <|START_OF_USER|><|USER|><|TEXT|>phonemes<|TEXT|><|VOICE|>voice<|VOICE|> <|END_OF_USER|><|GENERATE_SPEECH|><|start|><|SPEECH_123|><|SPEECH_456|>...<|end|> Returns: List of token integers """ print("Parsing annotated_data string...") # Extract everything between <|start|> and <|end|> match = re.search(r'<\|start\|>(.*?)<\|end\|>', annotated_data) if not match: raise ValueError("Could not find <|start|>...<|end|> in annotated_data") speech_section = match.group(1) # Extract all <|SPEECH_XXX|> tokens token_pattern = r'<\|SPEECH_(\d+)\|>' tokens = [int(m) for m in re.findall(token_pattern, speech_section)] if not tokens: raise ValueError("No tokens found in annotated_data") print(f"✓ Extracted {len(tokens)} tokens") return tokens def parse_token_list(self, token_str: str) -> List[int]: """ Parse token list string in various formats. 
Accepts: - "[123, 456, 789]" - "123, 456, 789" - "123 456 789" Returns: List of token integers """ print("Parsing token list...") # Remove brackets if present token_str = token_str.strip().strip('[]') # Split by comma or space if ',' in token_str: tokens = [int(x.strip()) for x in token_str.split(',') if x.strip()] else: tokens = [int(x.strip()) for x in token_str.split() if x.strip()] if not tokens: raise ValueError("No valid tokens found in input") print(f"✓ Parsed {len(tokens)} tokens") return tokens def redistribute_codes(self, flattened_tokens: List[int]) -> List[torch.Tensor]: """ Redistribute flattened tokens back to 3 hierarchical layers. Pattern: [L1, L2, L3, L3, L2, L3, L3] repeating Inverse operation of flatten_codes() Args: flattened_tokens: Flat list of tokens Returns: List of 3 tensors [layer_1, layer_2, layer_3] """ print("Redistributing tokens to 3 SNAC layers...") layer_1 = [] layer_2 = [] layer_3 = [] # Pattern: [L1, L2, L3, L3, L2, L3, L3] = 7 tokens per pattern num_patterns = (len(flattened_tokens) + 6) // 7 for i in range(num_patterns): base_idx = 7 * i if base_idx < len(flattened_tokens): layer_1.append(flattened_tokens[base_idx]) if base_idx + 1 < len(flattened_tokens): layer_2.append(flattened_tokens[base_idx + 1]) if base_idx + 2 < len(flattened_tokens): layer_3.append(flattened_tokens[base_idx + 2]) if base_idx + 3 < len(flattened_tokens): layer_3.append(flattened_tokens[base_idx + 3]) if base_idx + 4 < len(flattened_tokens): layer_2.append(flattened_tokens[base_idx + 4]) if base_idx + 5 < len(flattened_tokens): layer_3.append(flattened_tokens[base_idx + 5]) if base_idx + 6 < len(flattened_tokens): layer_3.append(flattened_tokens[base_idx + 6]) # Convert to tensors codes = [ torch.tensor(layer_1, dtype=torch.long).unsqueeze(0).to(self.device), torch.tensor(layer_2, dtype=torch.long).unsqueeze(0).to(self.device), torch.tensor(layer_3, dtype=torch.long).unsqueeze(0).to(self.device) ] print(f"✓ Redistributed to:") print(f" Layer 1: 
{codes[0].shape} ({codes[0].numel()} tokens)") print(f" Layer 2: {codes[1].shape} ({codes[1].numel()} tokens)") print(f" Layer 3: {codes[2].shape} ({codes[2].numel()} tokens)") return codes def decode_to_audio(self, tokens: List[int]) -> torch.Tensor: """ Decode token list to audio waveform. Args: tokens: Flattened token list Returns: Audio waveform tensor [1, samples] """ print("\nDecoding tokens to audio...") # Redistribute to 3 layers codes = self.redistribute_codes(tokens) # Decode with SNAC with torch.no_grad(): audio = self.model.decode(codes) # Ensure proper shape [1, samples] audio = audio.detach().cpu().float().squeeze() if audio.dim() == 0: audio = audio.unsqueeze(0) if audio.dim() == 1: audio = audio.unsqueeze(0) duration = audio.shape[-1] / self.sample_rate print(f"✓ Decoded audio:") print(f" Shape: {audio.shape}") print(f" Duration: {duration:.2f} seconds") print(f" Sample rate: {self.sample_rate} Hz") return audio def save_wav(self, audio: torch.Tensor, output_path: str): """Save audio tensor to WAV file.""" print(f"\nSaving audio to: {output_path}") torchaudio.save(output_path, audio, self.sample_rate) print(f"✓ Saved successfully") def decode_from_annotated(self, annotated_data: str, output_path: str): """ Complete pipeline: Parse annotated_data → Decode → Save WAV Args: annotated_data: Full annotated data string output_path: Output WAV file path """ print("="*70) print("DECODING FROM ANNOTATED_DATA") print("="*70) tokens = self.parse_annotated_data(annotated_data) audio = self.decode_to_audio(tokens) self.save_wav(audio, output_path) print("\n" + "="*70) print("✅ DECODING COMPLETE") print("="*70) def decode_from_tokens(self, tokens: Union[List[int], str], output_path: str): """ Complete pipeline: Parse tokens → Decode → Save WAV Args: tokens: Token list or string representation output_path: Output WAV file path """ print("="*70) print("DECODING FROM TOKEN LIST") print("="*70) if isinstance(tokens, str): tokens = self.parse_token_list(tokens) audio 
= self.decode_to_audio(tokens) self.save_wav(audio, output_path) print("\n" + "="*70) print("✅ DECODING COMPLETE") print("="*70) def decode_from_json(self, json_path: str, output_path: str, use_annotated: bool = True): """ Decode from JSON file containing dataset entry. Args: json_path: Path to JSON file output_path: Output WAV file path use_annotated: If True, parse annotated_data; else use tokens directly """ print("="*70) print("DECODING FROM JSON FILE") print("="*70) print(f"Reading: {json_path}") with open(json_path, 'r', encoding='utf-8') as f: data = json.load(f) print(f"✓ Loaded JSON data") # Display metadata if 'transcript' in data: print(f"\nTranscript: {data['transcript'][:80]}...") if 'phonemes' in data: print(f"Phonemes: {data['phonemes'][:80]}...") if 'voice' in data: print(f"Voice: {data['voice']}") if 'total_tokens' in data: print(f"Total tokens: {data['total_tokens']}") if 'duration_sec' in data: print(f"Expected duration: {data['duration_sec']}s") # Choose decoding method if use_annotated and 'annotated_data' in data: print("\nUsing annotated_data for decoding...") tokens = self.parse_annotated_data(data['annotated_data']) elif 'tokens' in data: print("\nUsing tokens array for decoding...") tokens = data['tokens'] else: raise ValueError("JSON must contain either 'annotated_data' or 'tokens'") audio = self.decode_to_audio(tokens) self.save_wav(audio, output_path) print("\n" + "="*70) print("✅ DECODING COMPLETE") print("="*70) def main(): parser = argparse.ArgumentParser( description='SNAC Decoder - Convert tokens or annotated_data to WAV audio', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # From annotated_data string python decoder.py --annotated "<|START_OF_USER|>...<|end|>" --output audio.wav # From token list python decoder.py --tokens "[123, 456, 789, 234, 567]" --output audio.wav # From JSON file (using annotated_data) python decoder.py --json sample.json --output audio.wav # From JSON file (using tokens 
directly) python decoder.py --json sample.json --output audio.wav --use-tokens """ ) # Input options (mutually exclusive) input_group = parser.add_mutually_exclusive_group(required=True) input_group.add_argument('--annotated', type=str, help='Annotated data string') input_group.add_argument('--tokens', type=str, help='Token list as string: "[123, 456, ...]"') input_group.add_argument('--json', type=str, help='Path to JSON file') # Output parser.add_argument('--output', '-o', type=str, required=True, help='Output WAV file path') # Options parser.add_argument('--use-tokens', action='store_true', help='When using --json, decode from tokens array instead of annotated_data') parser.add_argument('--device', type=str, default=None, help='Device: cuda or cpu (default: auto-detect)') args = parser.parse_args() # Initialize decoder device = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu") decoder = SNACDecoder(device=device) # Decode based on input type try: if args.annotated: decoder.decode_from_annotated(args.annotated, args.output) elif args.tokens: decoder.decode_from_tokens(args.tokens, args.output) elif args.json: use_annotated = not args.use_tokens decoder.decode_from_json(args.json, args.output, use_annotated=use_annotated) except Exception as e: print(f"\n❌ ERROR: {e}") import traceback traceback.print_exc() return 1 return 0 if __name__ == "__main__": exit(main()) ```
提供机构:
humair025
5,000+
优质数据集
54 个
任务类型
进入经典数据集
二维码
社区交流群

面向社区/商业的数据集话题

二维码
科研交流群

面向高校/科研机构的开源数据集话题

数据驱动未来

携手共赢发展

商业合作