From 3e93bf93808e72c3a9972580abe08b6a796053d7 Mon Sep 17 00:00:00 2001 From: Aleksey Smolenchuk Date: Sun, 26 Nov 2023 13:49:53 -0800 Subject: [PATCH] Added importable tts.py --- .gitignore | 2 + Demo/Inference_LJSpeech.ipynb | 10 +- Demo/Inference_LibriTTS.ipynb | 16 +- MANIFEST.in | 1 + README.md | 34 +++ setup.py | 22 ++ {Modules => styletts2/Modules}/__init__.py | 0 .../Modules}/diffusion/__init__.py | 0 .../Modules}/diffusion/diffusion.py | 0 .../Modules}/diffusion/modules.py | 0 .../Modules}/diffusion/sampler.py | 0 .../Modules}/diffusion/utils.py | 0 .../Modules}/discriminators.py | 0 {Modules => styletts2/Modules}/hifigan.py | 0 {Modules => styletts2/Modules}/istftnet.py | 0 {Modules => styletts2/Modules}/slmadv.py | 0 {Modules => styletts2/Modules}/utils.py | 0 {Utils => styletts2/Utils}/ASR/__init__.py | 0 {Utils => styletts2/Utils}/ASR/config.yml | 0 .../Utils}/ASR/epoch_00080.pth | Bin {Utils => styletts2/Utils}/ASR/layers.py | 0 {Utils => styletts2/Utils}/ASR/models.py | 0 {Utils => styletts2/Utils}/JDC/__init__.py | 0 {Utils => styletts2/Utils}/JDC/bst.t7 | Bin {Utils => styletts2/Utils}/JDC/model.py | 0 styletts2/Utils/PLBERT/__init__.py | 0 {Utils => styletts2/Utils}/PLBERT/config.yml | 0 .../Utils}/PLBERT/step_1000000.t7 | Bin {Utils => styletts2/Utils}/PLBERT/util.py | 0 {Utils => styletts2/Utils}/__init__.py | 0 styletts2/__init__.py | 1 + losses.py => styletts2/losses.py | 0 models.py => styletts2/models.py | 16 +- text_utils.py => styletts2/text_utils.py | 0 styletts2/tts.py | 207 ++++++++++++++++++ utils.py => styletts2/utils.py | 1 - train_finetune.py | 16 +- train_first.py | 8 +- train_second.py | 16 +- 39 files changed, 308 insertions(+), 42 deletions(-) create mode 100644 .gitignore create mode 100644 MANIFEST.in create mode 100644 setup.py rename {Modules => styletts2/Modules}/__init__.py (100%) rename {Modules => styletts2/Modules}/diffusion/__init__.py (100%) rename {Modules => styletts2/Modules}/diffusion/diffusion.py (100%) rename {Modules => styletts2/Modules}/diffusion/modules.py (100%) rename {Modules => styletts2/Modules}/diffusion/sampler.py (100%) rename {Modules => styletts2/Modules}/diffusion/utils.py (100%) rename {Modules => styletts2/Modules}/discriminators.py (100%) rename {Modules => styletts2/Modules}/hifigan.py (100%) rename {Modules => styletts2/Modules}/istftnet.py (100%) rename {Modules => styletts2/Modules}/slmadv.py (100%) rename {Modules => styletts2/Modules}/utils.py (100%) rename {Utils => styletts2/Utils}/ASR/__init__.py (100%) rename {Utils => styletts2/Utils}/ASR/config.yml (100%) rename {Utils => styletts2/Utils}/ASR/epoch_00080.pth (100%) rename {Utils => styletts2/Utils}/ASR/layers.py (100%) rename {Utils => styletts2/Utils}/ASR/models.py (100%) rename {Utils => styletts2/Utils}/JDC/__init__.py (100%) rename {Utils => styletts2/Utils}/JDC/bst.t7 (100%) rename {Utils => styletts2/Utils}/JDC/model.py (100%) create mode 100644 styletts2/Utils/PLBERT/__init__.py rename {Utils => styletts2/Utils}/PLBERT/config.yml (100%) rename {Utils => styletts2/Utils}/PLBERT/step_1000000.t7 (100%) rename {Utils => styletts2/Utils}/PLBERT/util.py (100%) rename {Utils => styletts2/Utils}/__init__.py (100%) create mode 100644 styletts2/__init__.py rename losses.py => styletts2/losses.py (100%) rename models.py => styletts2/models.py (98%) rename text_utils.py => styletts2/text_utils.py (100%) create mode 100644 styletts2/tts.py rename utils.py => styletts2/utils.py (99%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..3061b41e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.egg-info +build diff --git a/Demo/Inference_LJSpeech.ipynb b/Demo/Inference_LJSpeech.ipynb index 3a6923e4..7b41741c 100644 --- a/Demo/Inference_LJSpeech.ipynb +++ b/Demo/Inference_LJSpeech.ipynb @@ -65,9 +65,9 @@ "import librosa\n", "from nltk.tokenize import word_tokenize\n", "\n", - "from models import *\n", - "from utils import *\n", - "from text_utils import TextCleaner\n", + "from styletts2.models import *\n", + "from styletts2.utils import *\n", + "from styletts2.text_utils import TextCleaner\n", "textclenaer = TextCleaner()\n", "\n", "%matplotlib inline" @@ -160,7 +160,7 @@ "pitch_extractor = load_F0_models(F0_path)\n", "\n", "# load BERT model\n", - "from Utils.PLBERT.util import load_plbert\n", + "from styletts2.Utils.PLBERT.util import load_plbert\n", "BERT_path = config.get('PLBERT_dir', False)\n", "plbert = load_plbert(BERT_path)" ] @@ -221,7 +221,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule" + "from styletts2.Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule" ] }, { diff --git a/Demo/Inference_LibriTTS.ipynb b/Demo/Inference_LibriTTS.ipynb index 4b85bf5f..d2d37294 100644 --- a/Demo/Inference_LibriTTS.ipynb +++ b/Demo/Inference_LibriTTS.ipynb @@ -67,9 +67,9 @@ "import librosa\n", "from nltk.tokenize import word_tokenize\n", "\n", - "from models import *\n", - "from utils import *\n", - "from text_utils import TextCleaner\n", + "from styletts2.models import *\n", + "from styletts2.utils import *\n", + "from styletts2.text_utils import TextCleaner\n", "textclenaer = TextCleaner()\n", "\n", "%matplotlib inline" @@ -160,7 +160,7 @@ "pitch_extractor = load_F0_models(F0_path)\n", "\n", "# load BERT model\n", - "from Utils.PLBERT.util import load_plbert\n", + "from styletts2.Utils.PLBERT.util import load_plbert\n", "BERT_path = config.get('PLBERT_dir', False)\n", "plbert = load_plbert(BERT_path)" ] @@ -222,7 +222,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule" + "from styletts2.Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule" ] }, { @@ -1133,9 +1133,9 @@ ], "metadata": { "kernelspec": { - "display_name": "NLP", + "display_name": "Python 3", "language": "python", - "name": "nlp" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1147,7 +1147,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..7b5b5064 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include styletts2 * diff --git a/README.md b/README.md index 22050c3f..0bdee37b 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,40 @@ Please make sure you have the LibriTTS checkpoint downloaded and unzipped under - **Out of memory after `joint_epoch`**: This is likely because your GPU RAM is not big enough for SLM adversarial training run. You may skip that but the quality could be worse. Setting `joint_epoch` a larger number than `epochs` could skip the SLM advesariral training. ## Inference + +Quick start example: + +```python +from styletts2 import TTS +import sounddevice as sd +import phonemizer + +tts = TTS.load_model( + config_path="hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml", + checkpoint_path="hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth" +) + +es_phonemizer = phonemizer.backend.EspeakBackend( + language='en-us', + preserve_punctuation=True, + with_stress=True +) + +style = tts.compute_style('../tts-server/tts_server/voices/en-f-1.wav') + +wav, _ = tts.inference( + "This is a text! Hello world! How are you? What's your name?", + style, + phonemizer=es_phonemizer, + alpha=0.3, + beta=0.7, + diffusion_steps=10, + embedding_scale=2) + +sd.play(wav, 24000) +sd.wait() +``` + Please refer to [Inference_LJSpeech.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LJSpeech.ipynb) (single-speaker) and [Inference_LibriTTS.ipynb](https://github.com/yl4579/StyleTTS2/blob/main/Demo/Inference_LibriTTS.ipynb) (multi-speaker) for details. For LibriTTS, you will also need to download [reference_audio.zip](https://huggingface.co/yl4579/StyleTTS2-LibriTTS/resolve/main/reference_audio.zip) and unzip it under the `demo` before running the demo. - The pretrained StyleTTS 2 on LJSpeech corpus in 24 kHz can be downloaded at [https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main](https://huggingface.co/yl4579/StyleTTS2-LJSpeech/tree/main). diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..ee01b049 --- /dev/null +++ b/setup.py @@ -0,0 +1,22 @@ +from setuptools import setup, find_packages + +setup( + name="styletts2", + version="0.0.1", + packages=find_packages(), + include_package_data=True, + install_requires=[ + "cached_path", + "nltk", + "scipy", + "numpy", + "munch", + "librosa", + "sounddevice", + "einops", + "einops_exts", + "transformers", + "matplotlib", + "monotonic_align @ git+https://github.com/resemble-ai/monotonic_align.git", + ] +) diff --git a/Modules/__init__.py b/styletts2/Modules/__init__.py similarity index 100% rename from Modules/__init__.py rename to styletts2/Modules/__init__.py diff --git a/Modules/diffusion/__init__.py b/styletts2/Modules/diffusion/__init__.py similarity index 100% rename from Modules/diffusion/__init__.py rename to styletts2/Modules/diffusion/__init__.py diff --git a/Modules/diffusion/diffusion.py b/styletts2/Modules/diffusion/diffusion.py similarity index 100% rename from Modules/diffusion/diffusion.py rename to styletts2/Modules/diffusion/diffusion.py diff --git a/Modules/diffusion/modules.py b/styletts2/Modules/diffusion/modules.py similarity index 100% rename from Modules/diffusion/modules.py rename to styletts2/Modules/diffusion/modules.py diff --git a/Modules/diffusion/sampler.py b/styletts2/Modules/diffusion/sampler.py similarity index 100% rename from Modules/diffusion/sampler.py rename to styletts2/Modules/diffusion/sampler.py diff --git a/Modules/diffusion/utils.py b/styletts2/Modules/diffusion/utils.py similarity index 100% rename from Modules/diffusion/utils.py rename to styletts2/Modules/diffusion/utils.py diff --git a/Modules/discriminators.py b/styletts2/Modules/discriminators.py similarity index 100% rename from Modules/discriminators.py rename to styletts2/Modules/discriminators.py diff --git a/Modules/hifigan.py b/styletts2/Modules/hifigan.py similarity index 100% rename from Modules/hifigan.py rename to styletts2/Modules/hifigan.py diff --git a/Modules/istftnet.py b/styletts2/Modules/istftnet.py similarity index 100% rename from Modules/istftnet.py rename to styletts2/Modules/istftnet.py diff --git a/Modules/slmadv.py b/styletts2/Modules/slmadv.py similarity index 100% rename from Modules/slmadv.py rename to styletts2/Modules/slmadv.py diff --git a/Modules/utils.py b/styletts2/Modules/utils.py similarity index 100% rename from Modules/utils.py rename to styletts2/Modules/utils.py diff --git a/Utils/ASR/__init__.py b/styletts2/Utils/ASR/__init__.py similarity index 100% rename from Utils/ASR/__init__.py rename to styletts2/Utils/ASR/__init__.py diff --git a/Utils/ASR/config.yml b/styletts2/Utils/ASR/config.yml similarity index 100% rename from Utils/ASR/config.yml rename to styletts2/Utils/ASR/config.yml diff --git a/Utils/ASR/epoch_00080.pth b/styletts2/Utils/ASR/epoch_00080.pth similarity index 100% rename from Utils/ASR/epoch_00080.pth rename to styletts2/Utils/ASR/epoch_00080.pth diff --git a/Utils/ASR/layers.py b/styletts2/Utils/ASR/layers.py similarity index 100% rename from Utils/ASR/layers.py rename to styletts2/Utils/ASR/layers.py diff --git a/Utils/ASR/models.py b/styletts2/Utils/ASR/models.py similarity index 100% rename from Utils/ASR/models.py rename to styletts2/Utils/ASR/models.py diff --git a/Utils/JDC/__init__.py b/styletts2/Utils/JDC/__init__.py similarity index 100% rename from Utils/JDC/__init__.py rename to styletts2/Utils/JDC/__init__.py diff --git a/Utils/JDC/bst.t7 b/styletts2/Utils/JDC/bst.t7 similarity index 100% rename from Utils/JDC/bst.t7 rename to styletts2/Utils/JDC/bst.t7 diff --git a/Utils/JDC/model.py b/styletts2/Utils/JDC/model.py similarity index 100% rename from Utils/JDC/model.py rename to styletts2/Utils/JDC/model.py diff --git a/styletts2/Utils/PLBERT/__init__.py b/styletts2/Utils/PLBERT/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Utils/PLBERT/config.yml b/styletts2/Utils/PLBERT/config.yml similarity index 100% rename from Utils/PLBERT/config.yml rename to styletts2/Utils/PLBERT/config.yml diff --git a/Utils/PLBERT/step_1000000.t7 b/styletts2/Utils/PLBERT/step_1000000.t7 similarity index 100% rename from Utils/PLBERT/step_1000000.t7 rename to styletts2/Utils/PLBERT/step_1000000.t7 diff --git a/Utils/PLBERT/util.py b/styletts2/Utils/PLBERT/util.py similarity index 100% rename from Utils/PLBERT/util.py rename to styletts2/Utils/PLBERT/util.py diff --git a/Utils/__init__.py b/styletts2/Utils/__init__.py similarity index 100% rename from Utils/__init__.py rename to styletts2/Utils/__init__.py diff --git a/styletts2/__init__.py b/styletts2/__init__.py new file mode 100644 index 00000000..721ab2af --- /dev/null +++ b/styletts2/__init__.py @@ -0,0 +1 @@ +from .tts import TTS diff --git a/losses.py b/styletts2/losses.py similarity index 100% rename from losses.py rename to styletts2/losses.py diff --git a/models.py b/styletts2/models.py similarity index 98% rename from models.py rename to styletts2/models.py index 84bbb03d..4f90068e 100644 --- a/models.py +++ b/styletts2/models.py @@ -12,14 +12,14 @@ import torch.nn.functional as F from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from Utils.ASR.models import ASRCNN -from Utils.JDC.model import JDCNet +from .Utils.ASR.models import ASRCNN +from .Utils.JDC.model import JDCNet -from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution -from Modules.diffusion.modules import Transformer1d, StyleTransformer1d -from Modules.diffusion.diffusion import AudioDiffusionConditional +from .Modules.diffusion.sampler import KDiffusion, LogNormalDistribution +from .Modules.diffusion.modules import Transformer1d, StyleTransformer1d +from .Modules.diffusion.diffusion import AudioDiffusionConditional -from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator +from .Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator from munch import Munch import yaml @@ -615,7 +615,7 @@ def build_model(args, text_aligner, pitch_extractor, bert): assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown' if args.decoder.type == "istftnet": - from Modules.istftnet import Decoder + from .Modules.istftnet import Decoder decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels, resblock_kernel_sizes = args.decoder.resblock_kernel_sizes, upsample_rates = args.decoder.upsample_rates, @@ -624,7 +624,7 @@ def build_model(args, text_aligner, pitch_extractor, bert): upsample_kernel_sizes=args.decoder.upsample_kernel_sizes, gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size) else: - from Modules.hifigan import Decoder + from .Modules.hifigan import Decoder decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels, resblock_kernel_sizes = args.decoder.resblock_kernel_sizes, upsample_rates = args.decoder.upsample_rates, diff --git a/text_utils.py b/styletts2/text_utils.py similarity index 100% rename from text_utils.py rename to styletts2/text_utils.py diff --git a/styletts2/tts.py b/styletts2/tts.py new file mode 100644 index 00000000..5d8aacdb --- /dev/null +++ b/styletts2/tts.py @@ -0,0 +1,207 @@ +import torch +import torchaudio +import yaml +import librosa +import nltk +import os +import random + +import numpy as np +from nltk.tokenize import word_tokenize +from cached_path import cached_path + +# Import necessary modules and functions +from .Modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule +from .Utils.PLBERT.util import load_plbert +from .models import build_model, load_ASR_models, load_F0_models +from .text_utils import TextCleaner +from .utils import recursive_munch + + +class TTS: + def __init__(self, model_params, model, device): + if not nltk.find('tokenizers/punkt'): + nltk.download('punkt') + + self.model_params = model_params + self.model = model + self.device = device + self.text_cleaner = TextCleaner() + self.schedule = KarrasSchedule( + sigma_min=0.0001, sigma_max=3.0, rho=9.0) + self.sampler = DiffusionSampler( + model.diffusion.diffusion, + sampler=ADPM2Sampler(), + sigma_schedule=self.schedule, + clamp=False + ) + + @classmethod + def load_model(cls, config_path, checkpoint_path): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + config = yaml.safe_load(open(str(cached_path(config_path)))) + + # Process paths and load components + ASR_config = cls.fix_path(config.get('ASR_config', False)) + ASR_path = cls.fix_path(config.get('ASR_path', False)) + F0_path = cls.fix_path(config.get('F0_path', False)) + BERT_path = cls.fix_path(config.get('PLBERT_dir', False)) + + text_aligner = load_ASR_models(ASR_path, ASR_config) + pitch_extractor = load_F0_models(F0_path) + plbert = load_plbert(BERT_path) + + model_params = recursive_munch(config['model_params']) + model = build_model(model_params, text_aligner, + pitch_extractor, plbert) + + # Load state dicts + params_whole = torch.load( + str(cached_path(checkpoint_path)), map_location='cpu') + params = params_whole['net'] + cls.load_state_dicts(model, params) + + [model[key].eval() for key in model] + [model[key].to(device) for key in model] + + return cls(model_params, model, device) + + @staticmethod + def fix_path(path): + # if path is relative, make it absolute + if not os.path.isabs(path): + path = os.path.join(os.path.dirname(__file__), path) + return path + + @staticmethod + def load_state_dicts(model, params): + for key in model: + if key in params: + state_dict = params[key] + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + try: + model[key].load_state_dict(state_dict) + except RuntimeError as e: + print(f"Error loading state dict for {key}: {e}") + + @staticmethod + def preprocess_audio(wave, mean=-4, std=4): + wave_tensor = torch.from_numpy(wave).float() + mel_tensor = torchaudio.transforms.MelSpectrogram( + n_mels=80, n_fft=2048, win_length=1200, hop_length=300)(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std + return mel_tensor + + @staticmethod + def length_to_mask(lengths): + mask = torch.arange(lengths.max()).unsqueeze( + 0).expand(lengths.shape[0], -1).type_as(lengths) + mask = torch.gt(mask+1, lengths.unsqueeze(1)) + return mask + + @staticmethod + def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + def compute_style(self, path): + wave, sr = librosa.load(path, sr=24000) + audio, index = librosa.effects.trim(wave, top_db=30) + if sr != 24000: + audio = librosa.resample(audio, sr, 24000) + mel_tensor = self.preprocess_audio(audio).to(self.device) + + with torch.no_grad(): + ref_s = self.model.style_encoder(mel_tensor.unsqueeze(1)) + ref_p = self.model.predictor_encoder(mel_tensor.unsqueeze(1)) + + return torch.cat([ref_s, ref_p], dim=1) + + def inference(self, text, ref_s, prev_s=None, alpha=0.3, beta=0.7, t=0.7, phonemizer=None, diffusion_steps=5, embedding_scale=1): + if phonemizer is None: + raise ValueError("Phonemizer is required for inference") + + # Preprocess text + text = text.strip() + ps = phonemizer.phonemize([text]) + ps = ' '.join(word_tokenize(ps[0])) + + # Prepare tokens + tokens = torch.LongTensor( + [0] + self.text_cleaner(ps)).to(self.device).unsqueeze(0) + + with torch.no_grad(): + input_lengths = torch.LongTensor( + [tokens.shape[-1]]).to(self.device) + text_mask = self.length_to_mask(input_lengths).to(self.device) + + # Encode text + t_en = self.model.text_encoder(tokens, input_lengths, text_mask) + bert_dur = self.model.bert( + tokens, attention_mask=(~text_mask).int()) + d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2) + + # Predict style + s_pred = self.sampler( + noise=torch.randn((1, 256)).unsqueeze(1).to(self.device), + embedding=bert_dur, + embedding_scale=embedding_scale, + features=ref_s, # reference from the same speaker as the embedding + num_steps=diffusion_steps + ).squeeze(1) + + if prev_s is not None: + # Convex combination of previous and current style + s_pred = t * prev_s + (1 - t) * s_pred + + s = s_pred[:, 128:] + ref = s_pred[:, :128] + + ref = alpha * ref + (1 - alpha) * ref_s[:, :128] + s = beta * s + (1 - beta) * ref_s[:, 128:] + + s_pred = torch.cat([ref, s], dim=-1) + + # Predict duration + d = self.model.predictor.text_encoder( + d_en, s, input_lengths, text_mask) + x, _ = self.model.predictor.lstm(d) + duration = self.model.predictor.duration_proj(x) + duration = torch.sigmoid(duration).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + + # Create alignment target + pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data)) + c_frame = 0 + for i in range(pred_aln_trg.size(0)): + pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1 + c_frame += int(pred_dur[i].data) + + # Encode prosody + en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device)) + if self.model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(en) + asr_new[:, :, 0] = en[:, :, 0] + asr_new[:, :, 1:] = en[:, :, 0:-1] + en = asr_new + + # Predict F0 and N + F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s) + + asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device)) + if self.model_params.decoder.type == "hifigan": + asr_new = torch.zeros_like(asr) + asr_new[:, :, 0] = asr[:, :, 0] + asr_new[:, :, 1:] = asr[:, :, 0:-1] + asr = asr_new + + # Decode + out = self.model.decoder( + asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0)) + + # Fix weird pulse at the end later + return out.squeeze().cpu().numpy()[..., :-100], s_pred diff --git a/utils.py b/styletts2/utils.py similarity index 99% rename from utils.py rename to styletts2/utils.py index c2206d92..1feac2c6 100644 --- a/utils.py +++ b/styletts2/utils.py @@ -71,4 +71,3 @@ def recursive_munch(d): def log_print(message, logger): logger.info(message) print(message) - \ No newline at end of file diff --git a/train_finetune.py b/train_finetune.py index 3c650747..d6564e68 100644 --- a/train_finetune.py +++ b/train_finetune.py @@ -17,16 +17,16 @@ from meldataset import build_dataloader -from Utils.ASR.models import ASRCNN -from Utils.JDC.model import JDCNet -from Utils.PLBERT.util import load_plbert +from styletts2.Utils.ASR.models import ASRCNN +from styletts2.Utils.JDC.model import JDCNet +from styletts2.Utils.PLBERT.util import load_plbert -from models import * -from losses import * -from utils import * +from styletts2.models import * +from styletts2.losses import * +from styletts2.utils import * -from Modules.slmadv import SLMAdversarialLoss -from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule +from styletts2.Modules.slmadv import SLMAdversarialLoss +from styletts2.Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule from optimizers import build_optimizer diff --git a/train_first.py b/train_first.py index eaa8fe64..1f69b199 100644 --- a/train_first.py +++ b/train_first.py @@ -21,10 +21,10 @@ import torchaudio import librosa -from models import * +from styletts2.models import * from meldataset import build_dataloader -from utils import * -from losses import * +from styletts2.utils import * +from styletts2.losses import * from optimizers import build_optimizer import time @@ -108,7 +108,7 @@ def main(config_path): pitch_extractor = load_F0_models(F0_path) # load BERT model - from Utils.PLBERT.util import load_plbert + from styletts2.Utils.PLBERT.util import load_plbert BERT_path = config.get('PLBERT_dir', False) plbert = load_plbert(BERT_path) diff --git a/train_second.py b/train_second.py index 848de51f..fed37dd0 100644 --- a/train_second.py +++ b/train_second.py @@ -17,16 +17,16 @@ from meldataset import build_dataloader -from Utils.ASR.models import ASRCNN -from Utils.JDC.model import JDCNet -from Utils.PLBERT.util import load_plbert +from styletts2.Utils.ASR.models import ASRCNN +from styletts2.Utils.JDC.model import JDCNet +from styletts2.Utils.PLBERT.util import load_plbert -from models import * -from losses import * -from utils import * +from styletts2.models import * +from styletts2.losses import * +from styletts2.utils import * -from Modules.slmadv import SLMAdversarialLoss -from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule +from styletts2.Modules.slmadv import SLMAdversarialLoss +from styletts2.Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule from optimizers import build_optimizer