Source code for parakeet.models.transformer_tts

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from tqdm import trange

import parakeet
from parakeet.modules import losses as L
from parakeet.modules import masking
from parakeet.modules import positional_encoding as pe
from parakeet.modules.attention import _concat_heads
from parakeet.modules.attention import _split_heads
from parakeet.modules.attention import drop_head
from parakeet.modules.attention import scaled_dot_product_attention
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.transformer import PositionwiseFFN
from parakeet.utils import checkpoint
from parakeet.utils import scheduler

__all__ = ["TransformerTTS", "TransformerTTSLoss"]


# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
    """Multihead scaled dot product attention with drop head. See 
    [Scheduled DropHead: A Regularization Method for Transformer Models](https://arxiv.org/abs/2004.13342) 
    for details.
    
    Another deviation is that it concats the input query and context vector before
    applying the output projection.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 k_dim=None,
                 v_dim=None,
                 k_input_dim=None,
                 v_input_dim=None):
        """
        Args:
            model_dim (int): the feature size of query.
            num_heads (int): the number of attention heads.
            k_dim (int, optional): feature size of the key of each scaled dot 
                product attention. If not provided, it is set to 
                model_dim / num_heads. Defaults to None.
            v_dim (int, optional): feature size of the key of each scaled dot 
                product attention. If not provided, it is set to 
                model_dim / num_heads. Defaults to None.

        Raises:
            ValueError: if model_dim is not divisible by num_heads
        """
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        depth = model_dim // num_heads
        k_dim = k_dim or depth
        v_dim = v_dim or depth
        k_input_dim = k_input_dim or model_dim
        v_input_dim = v_input_dim or model_dim
        self.affine_q = nn.Linear(model_dim, num_heads * k_dim)
        self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
        self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
        self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)

        self.num_heads = num_heads
        self.model_dim = model_dim

    def forward(self, q, k, v, mask, drop_n_heads=0):
        """
        Compute context vector and attention weights.
        
        Args:
            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
            mask (Tensor): shape(batch_size, times_steps_q, time_steps_k) or 
                broadcastable shape, dtype: float32 or float64, the mask.

        Returns:
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
        """
        q_in = q
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
        v = _split_heads(self.affine_v(v), self.num_heads)
        if mask is not None:
            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, training=self.training)
        context_vectors = drop_head(context_vectors, drop_n_heads,
                                    self.training)
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)

        concat_feature = paddle.concat([q_in, context_vectors], -1)
        out = self.affine_o(concat_feature)
        return out, attention_weights


class TransformerEncoderLayer(nn.Layer):
    """
    Transformer encoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        """
        Args:
            d_model (int): the feature size of the input, and the output.
            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
            d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the dropout in 
                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
        """
        super(TransformerEncoderLayer, self).__init__()
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def _forward_mha(self, x, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, x, mask, drop_n_heads=0):
        """
        Args:
            x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
            mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
        
        Returns:
            x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
            attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
        """
        x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
        x = self._forward_ffn(x)
        return x, attn_weights


class TransformerDecoderLayer(nn.Layer):
    """
    Transformer decoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
        """
        Args:
            d_model (int): the feature size of the input, and the output.
            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
            d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the dropout in 
                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
        """
        super(TransformerDecoderLayer, self).__init__()
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.cross_mha = MultiheadAttention(
            d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def _forward_self_mha(self, x, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        q_in = q
        q = self.layer_norm2(q)
        context_vector, attn_weights = self.cross_mha(q, k, v, mask,
                                                      drop_n_heads)
        context_vector = q_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm3(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """
        Args:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
            k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
            v (Tensor): shape(batch_size, time_steps_k, d_model), values
            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k) encoder padding mask.
            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
        
        Returns:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
            self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
            cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
        """
        q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
                                                      drop_n_heads)
        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
                                                        drop_n_heads)
        q = self._forward_ffn(q)
        return q, self_attn_weights, cross_attn_weights


class TransformerEncoder(nn.LayerList):
    def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
        super(TransformerEncoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))

    def forward(self, x, mask, drop_n_heads=0):
        """
        Args:
            x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
            mask (Tensor): shape(batch_size, 1, time_steps), the mask.
            drop_n_heads (int, optional): how many heads to drop. Defaults to 0.

        Returns:
            x (Tensor): shape(batch_size, time_steps, feature_size), the context vector.
            attention_weights(list[Tensor]), each of shape
                (batch_size, n_heads, time_steps, time_steps), the attention weights.
        """
        attention_weights = []
        for layer in self:
            x, attention_weights_i = layer(x, mask, drop_n_heads)
            attention_weights.append(attention_weights_i)
        return x, attention_weights


class TransformerDecoder(nn.LayerList):
    def __init__(self,
                 d_model,
                 n_heads,
                 d_ffn,
                 n_layers,
                 dropout=0.,
                 d_encoder=None):
        super(TransformerDecoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerDecoderLayer(
                    d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """
        Args:
            q (Tensor): shape(batch_size, time_steps_q, d_model)
            k (Tensor): shape(batch_size, time_steps_k, d_encoder)
            v (Tensor): shape(batch_size, time_steps_k, k_encoder)
            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k)
            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
            drop_n_heads (int, optional): [description]. Defaults to 0.

        Returns:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the output.
            self_attention_weights (List[Tensor]): shape (batch_size, num_heads, encoder_steps, encoder_steps)
            cross_attention_weights (List[Tensor]): shape (batch_size, num_heads, decoder_steps, encoder_steps)
        """
        self_attention_weights = []
        cross_attention_weights = []
        for layer in self:
            q, self_attention_weights_i, cross_attention_weights_i = layer(
                q, k, v, encoder_mask, decoder_mask, drop_n_heads)
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return q, self_attention_weights, cross_attention_weights


class MLPPreNet(nn.Layer):
    """Decoder's prenet."""

    def __init__(self, d_input, d_hidden, d_output, dropout):
        # (lin + relu + dropout) * n + last projection
        super(MLPPreNet, self).__init__()
        self.lin1 = nn.Linear(d_input, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_hidden)
        self.lin3 = nn.Linear(d_hidden, d_output)
        self.dropout = dropout

    def forward(self, x, dropout):
        l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=True)
        l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=True)
        l3 = self.lin3(l2)
        return l3


class CNNPostNet(nn.Layer):
    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
        super(CNNPostNet, self).__init__()
        self.convs = nn.LayerList()
        kernel_size = kernel_size if isinstance(kernel_size, (
            tuple, list)) else (kernel_size, )
        padding = (kernel_size[0] - 1, 0)
        for i in range(n_layers):
            c_in = d_input if i == 0 else d_hidden
            c_out = d_output if i == n_layers - 1 else d_hidden
            self.convs.append(
                Conv1dBatchNorm(
                    c_in,
                    c_out,
                    kernel_size,
                    weight_attr=I.XavierUniform(),
                    padding=padding,
                    momentum=0.99,
                    epsilon=1e-03))
        self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3)
        # for a layer that ends with a normalization layer that is targeted to
        # output a non zero-central output, it may take a long time to 
        # train the scale and bias
        # NOTE: it can also be a non-causal conv

    def forward(self, x):
        x_in = x
        for i, layer in enumerate(self.convs):
            x = layer(x)
            if i != (len(self.convs) - 1):
                x = F.tanh(x)
        # TODO: check it
        # x = x_in + x
        x = self.last_bn(x_in + x)
        return x


[docs]class TransformerTTS(nn.Layer): def __init__(self, frontend: parakeet.frontend.Phonetics, d_encoder: int, d_decoder: int, d_mel: int, n_heads: int, d_ffn: int, encoder_layers: int, decoder_layers: int, d_prenet: int, d_postnet: int, postnet_layers: int, postnet_kernel_size: int, max_reduction_factor: int, decoder_prenet_dropout: float, dropout: float, n_tones=None): super(TransformerTTS, self).__init__() # text frontend (text normalization and g2p) self.frontend = frontend # encoder self.encoder_prenet = nn.Embedding( frontend.vocab_size, d_encoder, padding_idx=frontend.vocab.padding_index, weight_attr=I.Uniform(-0.05, 0.05)) if n_tones: self.toned = True self.tone_embed = nn.Embedding( n_tones, d_encoder, padding_idx=0, weight_attr=I.Uniform(-0.005, 0.005)) else: self.toned = False # position encoding matrix may be extended later self.encoder_pe = pe.sinusoid_position_encoding(1000, d_encoder) self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.)) self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout) # decoder self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout) self.decoder_pe = pe.sinusoid_position_encoding(1000, d_decoder) self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.)) self.decoder = TransformerDecoder( d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder) self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel) self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers) self.stop_conditioner = nn.Linear(d_mel, 3) # specs self.padding_idx = frontend.vocab.padding_index self.d_encoder = d_encoder self.d_decoder = d_decoder self.d_mel = d_mel self.max_r = max_reduction_factor self.dropout = dropout self.decoder_prenet_dropout = decoder_prenet_dropout # start and end: though it is only used in predict # it can also be used in training dtype = paddle.get_default_dtype() self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype) self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype) self.stop_prob_index = 2 # mutables self.r = max_reduction_factor # set it every call self.drop_n_heads = 0
[docs] def forward(self, text, mel, tones=None): encoded, encoder_attention_weights, encoder_mask = self.encode( text, tones=tones) mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode( encoded, mel, encoder_mask) outputs = { "mel_output": mel_output, "mel_intermediate": mel_intermediate, "encoder_attention_weights": encoder_attention_weights, "cross_attention_weights": cross_attention_weights, "stop_logits": stop_logits, } return outputs
[docs] def encode(self, text, tones=None): T_enc = text.shape[-1] embed = self.encoder_prenet(text) if self.toned: embed += self.tone_embed(tones) if embed.shape[1] > self.encoder_pe.shape[0]: new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2) self.encoder_pe = pe.sinusoid_position_encoding(new_T, self.d_encoder) pos_enc = self.encoder_pe[:T_enc, :] # (T, C) x = embed.scale( math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar x = F.dropout(x, self.dropout, training=self.training) # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask encoder_padding_mask = paddle.unsqueeze( masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1) x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads) return x, attention_weights, encoder_padding_mask
[docs] def decode(self, encoder_output, input, encoder_padding_mask): batch_size, T_dec, mel_dim = input.shape x = self.decoder_prenet(input, self.decoder_prenet_dropout) # twice its length if needed if x.shape[1] * self.r > self.decoder_pe.shape[0]: new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2) self.decoder_pe = pe.sinusoid_position_encoding(new_T, self.d_decoder) pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :] x = x.scale( math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar x = F.dropout(x, self.dropout, training=self.training) no_future_mask = masking.future_mask(T_dec, dtype=input.dtype) decoder_padding_mask = masking.feature_mask( input, axis=-1, dtype=input.dtype) decoder_mask = masking.combine_mask( decoder_padding_mask.unsqueeze(1), no_future_mask) decoder_output, _, cross_attention_weights = self.decoder( x, encoder_output, encoder_output, encoder_padding_mask, decoder_mask, self.drop_n_heads) # use only parts of it output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim] mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim]) stop_logits = self.stop_conditioner(mel_intermediate) # cnn postnet mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1]) mel_output = self.decoder_postnet(mel_channel_first) mel_output = paddle.transpose(mel_output, [0, 2, 1]) return mel_output, mel_intermediate, cross_attention_weights, stop_logits
[docs] @paddle.no_grad() def infer(self, input, max_length=1000, verbose=True, tones=None): """Predict log scale magnitude mel spectrogram from text input. Args: input (Tensor): shape (T), dtype int, input text sequencce. max_length (int, optional): max decoder steps. Defaults to 1000. verbose (bool, optional): display progress bar. Defaults to True. """ decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) # encoder the text sequence encoder_output, encoder_attentions, encoder_padding_mask = self.encode( input, tones=tones) for _ in trange(int(max_length // self.r) + 1): mel_output, _, cross_attention_weights, stop_logits = self.decode( encoder_output, decoder_input, encoder_padding_mask) # extract last step and append it to decoder input decoder_input = paddle.concat( [decoder_input, mel_output[:, -1:, :]], 1) # extract last r steps and append it to decoder output decoder_output = paddle.concat( [decoder_output, mel_output[:, -self.r:, :]], 1) # stop condition: (if any ouput frame of the output multiframes hits the stop condition) # import pdb; pdb.set_trace() if paddle.any( paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index): if verbose: print("Hits stop condition.") break mel_output = decoder_output[:, 1:, :] outputs = { "mel_output": mel_output, "encoder_attention_weights": encoder_attentions, "cross_attention_weights": cross_attention_weights, } return outputs
[docs] def set_constants(self, reduction_factor, drop_n_heads): self.r = reduction_factor self.drop_n_heads = drop_n_heads
[docs] @classmethod def from_pretrained(cls, frontend, config, checkpoint_path): model = TransformerTTS( frontend, d_encoder=config.model.d_encoder, d_decoder=config.model.d_decoder, d_mel=config.data.n_mels, n_heads=config.model.n_heads, d_ffn=config.model.d_ffn, encoder_layers=config.model.encoder_layers, decoder_layers=config.model.decoder_layers, d_prenet=config.model.d_prenet, d_postnet=config.model.d_postnet, postnet_layers=config.model.postnet_layers, postnet_kernel_size=config.model.postnet_kernel_size, max_reduction_factor=config.model.max_reduction_factor, decoder_prenet_dropout=config.model.decoder_prenet_dropout, dropout=config.model.dropout) iteration = checkpoint.load_parameters( model, checkpoint_path=checkpoint_path) drop_n_heads = scheduler.StepWise(config.training.drop_n_heads) reduction_factor = scheduler.StepWise(config.training.reduction_factor) model.set_constants( reduction_factor=reduction_factor(iteration), drop_n_heads=drop_n_heads(iteration)) return model
[docs]class TransformerTTSLoss(nn.Layer): def __init__(self, stop_loss_scale): super(TransformerTTSLoss, self).__init__() self.stop_loss_scale = stop_loss_scale
[docs] def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs): mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype) mask1 = paddle.unsqueeze(mask, -1) mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1) mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1) mel_len = mask.shape[-1] last_position = F.one_hot( mask.sum(-1).astype("int64") - 1, num_classes=mel_len) mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype( mask.dtype) stop_loss = L.masked_softmax_with_cross_entropy( stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1)) loss = mel_loss1 + mel_loss2 + stop_loss losses = dict( loss=loss, # total loss mel_loss1=mel_loss1, # ouput mel loss mel_loss2=mel_loss2, # intermediate mel loss stop_loss=stop_loss # stop prob loss ) return losses