parakeet.modules package
Submodules
parakeet.modules.attention module
- class parakeet.modules.attention.LocationSensitiveAttention(d_query: int, d_key: int, d_attention: int, location_filters: int, location_kernel_size: int)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
Location Sensitive Attention module.
Reference: Attention-Based Models for Speech Recognition
- Parameters
- d_query : int
The feature size of the query.
- d_key : int
The feature size of the key.
- d_attention : int
The feature size of the attention representation.
- location_filters : int
Filter size of the attention convolution.
- location_kernel_size : int
Kernel size of the attention convolution.
- forward(query, processed_key, value, attention_weights_cat, mask=None)[source]
Compute context vector and attention weights.
- Parameters
- query : Tensor [shape=(batch_size, d_query)]
The queries.
- processed_key : Tensor [shape=(batch_size, time_steps_k, d_attention)]
The keys after the linear layer.
- value : Tensor [shape=(batch_size, time_steps_k, d_key)]
The values.
- attention_weights_cat : Tensor [shape=(batch_size, time_steps_k, 2)]
The concatenated attention weights.
- mask : Tensor, optional
The mask. Shape should be (batch_size, time_steps_k, 1). Defaults to None.
- Returns
- attention_context : Tensor [shape=(batch_size, d_attention)]
The context vector.
- attention_weights : Tensor [shape=(batch_size, time_steps_k)]
The attention weights.
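Examples
A minimal usage sketch (illustrative only, not from the original docstring; the tensor values are random and the shapes simply follow the parameter descriptions above):
>>> import paddle
>>> from parakeet.modules.attention import LocationSensitiveAttention
>>> attn = LocationSensitiveAttention(
...     d_query=128, d_key=256, d_attention=64,
...     location_filters=8, location_kernel_size=15)
>>> query = paddle.randn([4, 128])             # (batch_size, d_query)
>>> processed_key = paddle.randn([4, 20, 64])  # (batch_size, time_steps_k, d_attention)
>>> value = paddle.randn([4, 20, 256])         # (batch_size, time_steps_k, d_key)
>>> weights_cat = paddle.randn([4, 20, 2])     # (batch_size, time_steps_k, 2)
>>> context, weights = attn(query, processed_key, value, weights_cat)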
- class parakeet.modules.attention.MonoheadAttention(model_dim: int, dropout: float = 0.0, k_dim: Optional[int] = None, v_dim: Optional[int] = None)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
Monohead Attention module.
- Parameters
- model_dim : int
Feature size of the query.
- dropout : float, optional
Dropout probability of the scaled dot product attention and the final context vector. Defaults to 0.0.
- k_dim : int, optional
Feature size of the key of the scaled dot product attention. If not provided, it is set to model_dim. Defaults to None.
- v_dim : int, optional
Feature size of the value of the scaled dot product attention. If not provided, it is set to model_dim. Defaults to None.
- forward(q, k, v, mask)[source]
Compute context vector and attention weights.
- Parameters
- q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
The queries.
- k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
The keys.
- v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
The values.
- mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
The mask.
- Returns
- out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
The context vector.
- attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
The attention weights.
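Examples
A minimal usage sketch (illustrative only; a mask of ones marks every key position as valid, since zeros correspond to paddings):
>>> import paddle
>>> from parakeet.modules.attention import MonoheadAttention
>>> attn = MonoheadAttention(model_dim=64)
>>> q = paddle.randn([2, 10, 64])   # (batch_size, time_steps_q, model_dim)
>>> k = paddle.randn([2, 12, 64])   # (batch_size, time_steps_k, model_dim)
>>> v = paddle.randn([2, 12, 64])
>>> mask = paddle.ones([2, 1, 12])  # broadcastable to (2, 10, 12)
>>> out, attention_weights = attn(q, k, v, mask)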
- class parakeet.modules.attention.MultiheadAttention(model_dim: int, num_heads: int, dropout: float = 0.0, k_dim: Optional[int] = None, v_dim: Optional[int] = None)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
Multihead Attention module.
- Parameters
- model_dim : int
The feature size of the query.
- num_heads : int
The number of attention heads.
- dropout : float, optional
Dropout probability of the scaled dot product attention and the final context vector. Defaults to 0.0.
- k_dim : int, optional
Feature size of the key of each scaled dot product attention. If not provided, it is set to model_dim / num_heads. Defaults to None.
- v_dim : int, optional
Feature size of the value of each scaled dot product attention. If not provided, it is set to model_dim / num_heads. Defaults to None.
- Raises
- ValueError
If model_dim is not divisible by num_heads.
- forward(q, k, v, mask)[source]
Compute context vector and attention weights.
- Parameters
- q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
The queries.
- k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
The keys.
- v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
The values.
- mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
The mask.
- Returns
- out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
The context vector.
- attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
The attention weights.
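Examples
A minimal usage sketch (illustrative only; model_dim must be divisible by num_heads, as the Raises section notes):
>>> import paddle
>>> from parakeet.modules.attention import MultiheadAttention
>>> attn = MultiheadAttention(model_dim=64, num_heads=4)  # 64 is divisible by 4
>>> q = paddle.randn([2, 10, 64])
>>> k = paddle.randn([2, 12, 64])
>>> v = paddle.randn([2, 12, 64])
>>> mask = paddle.ones([2, 1, 12])  # broadcastable to (2, 10, 12); zeros would mark padding
>>> out, attention_weights = attn(q, k, v, mask)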
- parakeet.modules.attention.drop_head(x, drop_n_heads, training=True)[source]
Drop n of the multiple context vectors (head dropout).
- Parameters
- x : Tensor [shape=(batch_size, num_heads, time_steps, channels)]
The input, multiple context vectors.
- drop_n_heads : int [0 <= drop_n_heads <= num_heads]
Number of vectors to drop.
- training : bool
A flag indicating whether it is in training. If False, no dropout is applied.
- Returns
- Tensor
The output.
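Examples
A short call sketch (illustrative only; per the parameter descriptions above, the output is expected to keep the input shape):
>>> import paddle
>>> from parakeet.modules.attention import drop_head
>>> x = paddle.randn([2, 4, 10, 16])    # (batch_size, num_heads, time_steps, channels)
>>> y = drop_head(x, 2, training=True)  # randomly drops 2 of the 4 context vectors
>>> z = drop_head(x, 2, training=False) # no dropout at evaluation time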
- parakeet.modules.attention.scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)[source]
Scaled dot product attention with masking.
Assume that q, k, v all have the same leading dimensions (denoted as * in descriptions below). Dropout is applied to attention weights before weighted sum of values.
- Parameters
- q : Tensor [shape=(*, T_q, d)]
The query tensor.
- k : Tensor [shape=(*, T_k, d)]
The key tensor.
- v : Tensor [shape=(*, T_k, d_v)]
The value tensor.
- mask : Tensor [shape=(*, T_q, T_k) or broadcastable shape], optional
The mask tensor; zeros correspond to paddings. Defaults to None.
- Returns
- out : Tensor [shape=(*, T_q, d_v)]
The context vector.
- attn_weights : Tensor [shape=(*, T_q, T_k)]
The attention weights.
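Examples
A minimal usage sketch (illustrative only; here the leading dimension * is just the batch size):
>>> import paddle
>>> from parakeet.modules.attention import scaled_dot_product_attention
>>> q = paddle.randn([2, 10, 64])    # (*, T_q, d)
>>> k = paddle.randn([2, 12, 64])    # (*, T_k, d)
>>> v = paddle.randn([2, 12, 32])    # (*, T_k, d_v)
>>> mask = paddle.ones([2, 10, 12])  # zeros would correspond to paddings
>>> out, attn_weights = scaled_dot_product_attention(q, k, v, mask=mask)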
parakeet.modules.audio module
- class parakeet.modules.audio.MelScale(sr, n_fft, n_mels, fmin, fmax)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
- class parakeet.modules.audio.STFT(n_fft, hop_length=None, win_length=None, window='hanning', center=True, pad_mode='reflect')[source]
Bases:
paddle.fluid.dygraph.layers.Layer
A module for computing the STFT transformation in a differentiable way.
- Parameters
- n_fft : int
Number of samples in a frame.
- hop_length : int
Number of samples shifted between adjacent frames.
- win_length : int
Length of the window.
- window : str, optional
Name of the window function; see scipy.signal.get_window for more details. Defaults to "hanning".
- center : bool
If True, the signal y is padded so that frame D[:, t] is centered at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length]. Defaults to True.
- pad_mode : string or function
If center=True, this argument is passed to np.pad for padding the edges of the signal y. By default (pad_mode="reflect"), y is padded on both sides with its own reflection, mirrored around its first and last sample respectively. If center=False, this argument is ignored.
Notes
It behaves like librosa.core.stft. See librosa.core.stft for more details.
Given an audio signal with T samples, the STFT transformation outputs a spectrum of shape (C, frames) with complex dtype, where C = 1 + n_fft / 2 and frames = 1 + T // hop_length.
Only center and reflect padding are supported now.
- forward(x)[source]
Compute the STFT transform.
- Parameters
- x : Tensor [shape=(B, T)]
The input waveform.
- Returns
- real : Tensor [shape=(B, C, frames)]
The real part of the spectrogram.
- imag : Tensor [shape=(B, C, frames)]
The imaginary part of the spectrogram.
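Examples
A minimal usage sketch (illustrative only; with these settings, C = 1 + 512 / 2 = 257 and frames = 1 + 16000 // 128 = 126 per the Notes above):
>>> import paddle
>>> from parakeet.modules.audio import STFT
>>> stft = STFT(n_fft=512, hop_length=128, win_length=512)
>>> x = paddle.randn([4, 16000])  # (B, T) batch of waveforms
>>> real, imag = stft(x)          # each of shape (B, C, frames)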
- parakeet.modules.audio.dequantize(quantized, n_bands, dtype=None)[source]
Linearly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
- Parameters
- quantized : Tensor [dtype: int]
The quantized value in the range [0, n_bands).
- n_bands : int
Number of bands. The input integer Tensor's value is in the range [0, n_bands).
- dtype : str, optional
Data type of the output.
- Returns
- Tensor
The dequantized tensor, whose dtype is specified by dtype. If dtype is not specified, the default float data type is used.
- parakeet.modules.audio.quantize(values, n_bands)[source]
Linearly quantize a float Tensor in [-1, 1) to an integer Tensor in [0, n_bands).
- Parameters
- values : Tensor [dtype: float32 or float64]
The floating point value.
- n_bands : int
The number of bands. The output integer Tensor's value is in the range [0, n_bands).
- Returns
- Tensor [dtype: int64]
The quantized tensor.
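Examples
A round-trip sketch for quantize and dequantize (illustrative only):
>>> import paddle
>>> from parakeet.modules.audio import quantize, dequantize
>>> x = paddle.uniform([4, 100], min=-1.0, max=1.0)  # float values in [-1, 1)
>>> q = quantize(x, n_bands=256)                     # integer values in [0, 256)
>>> x_hat = dequantize(q, n_bands=256)               # floats back in [-1, 1)
>>> # the round-trip error is bounded by the bin width, 2 / n_bands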
parakeet.modules.conv module
- class parakeet.modules.conv.Conv1dBatchNorm(in_channels, out_channels, kernel_size, stride=1, padding=0, weight_attr=None, bias_attr=None, data_format='NCL', momentum=0.9, epsilon=1e-05)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
A Conv1D Layer followed by a BatchNorm1D.
- Parameters
- in_channels : int
The feature size of the input.
- out_channels : int
The feature size of the output.
- kernel_size : int
The size of the convolution kernel.
- stride : int, optional
The stride of the convolution, by default 1.
- padding : int, str or Tuple[int], optional
The padding of the convolution. If int, a symmetrical padding is applied before convolution; if str, it should be "same" or "valid"; if Tuple[int], its length should be 2, meaning (pad_before, pad_after). By default 0.
- weight_attr : ParamAttr, Initializer, str or bool, optional
The parameter attribute of the convolution kernel, by default None.
- bias_attr : ParamAttr, Initializer, str or bool, optional
The parameter attribute of the bias of the convolution, by default None.
- data_format : str ["NCL" or "NLC"], optional
The data layout of the input, by default "NCL".
- momentum : float, optional
The momentum of the BatchNorm1D layer, by default 0.9.
- epsilon : float, optional
The epsilon of the BatchNorm1D layer, by default 1e-05.
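Examples
A minimal usage sketch (illustrative only; assuming the layer's call applies the convolution followed by batch normalization, with kernel_size=5 and padding=2 preserving the input length):
>>> import paddle
>>> from parakeet.modules.conv import Conv1dBatchNorm
>>> layer = Conv1dBatchNorm(80, 256, kernel_size=5, padding=2)
>>> x = paddle.randn([4, 80, 100])  # "NCL" layout: (batch, channels, length)
>>> y = layer(x)                    # (4, 256, 100)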
- class parakeet.modules.conv.Conv1dCell(in_channels, out_channels, kernel_size, dilation=1, weight_attr=None, bias_attr=None)[source]
Bases:
paddle.nn.layer.conv.Conv1D
A subclass of Conv1D layer, which can be used in an autoregressive decoder like an RNN cell.
When used in autoregressive decoding, it performs causal temporal convolution incrementally. At each time step, it takes a step input and returns a step output.
- Parameters
- in_channels : int
The feature size of the input.
- out_channels : int
The feature size of the output.
- kernel_size : int or Tuple[int]
The size of the kernel.
- dilation : int or Tuple[int]
The dilation of the convolution, by default 1.
- weight_attr : ParamAttr, Initializer, str or bool, optional
The parameter attribute of the convolution kernel, by default None.
- bias_attr : ParamAttr, Initializer, str or bool, optional
The parameter attribute of the bias. If False, this layer does not have a bias. By default None.
Notes
It is done by caching an internal buffer of length receptive_field - 1. When adding a step input, the buffer is shifted by one step, the latest input is appended to the buffer and the oldest step is discarded, and it returns a step output. For the single step case, convolution is equivalent to a linear transformation.
That it can be used as a cell depends on several restrictions: 1. the stride must be 1; 2. the padding must be a causal padding (receptive_field - 1, 0). Thus, these arguments are removed from the __init__ method of this class.
Examples
>>> import paddle
>>> from parakeet.modules.conv import Conv1dCell
>>> cell = Conv1dCell(3, 4, kernel_size=5)
>>> inputs = [paddle.randn([4, 3]) for _ in range(16)]
>>> outputs = []
>>> cell.eval()
>>> cell.start_sequence()
>>> for xt in inputs:
...     outputs.append(cell.add_input(xt))
>>> len(outputs)
16
>>> outputs[0].shape
[4, 4]
- add_input(x_t)[source]
Add step input and compute step output.
- Parameters
- x_t : Tensor [shape=(batch_size, in_channels)]
The step input.
- Returns
- y_t : Tensor [shape=(batch_size, out_channels)]
The step output.
- initialize_buffer(x_t)[source]
Initialize the buffer for the step input.
- Parameters
- x_t : Tensor [shape=(batch_size, in_channels)]
The step input.
- property receptive_field
The receptive field of the Conv1dCell.
parakeet.modules.geometry module
- parakeet.modules.geometry.shuffle_dim(x, axis, perm=None)[source]
Permute an input tensor along an axis, given the permutation or randomly.
- Parameters
- x : Tensor
The input tensor.
- axis : int
The axis to shuffle.
- perm : List[int], ndarray, optional
The order in which to reorder the tensor along the axis-th dimension. It is a permutation of [0, d), where d is the size of the axis-th dimension of the input tensor. If not provided, a random permutation is used. Defaults to None.
- Returns
- Tensor
The shuffled tensor, which has the same shape as x does.
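Examples
A minimal usage sketch (illustrative only):
>>> import paddle
>>> from parakeet.modules.geometry import shuffle_dim
>>> x = paddle.randn([2, 5, 8])
>>> y = shuffle_dim(x, axis=1)                        # random permutation along axis 1
>>> z = shuffle_dim(x, axis=1, perm=[4, 3, 2, 1, 0])  # an explicit permutation of [0, 5)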
parakeet.modules.losses module
- parakeet.modules.losses.guided_attention_loss(attention_weight, dec_lens, enc_lens, g)[source]
Guided attention loss, masked to exclude padding parts.
- parakeet.modules.losses.masked_l1_loss(prediction, target, mask)[source]
Compute masked L1 loss.
- Parameters
- prediction : Tensor
The prediction.
- target : Tensor
The target. The shape should be broadcastable to prediction.
- mask : Tensor
The mask. The shape should be broadcastable to the broadcasted shape of prediction and target.
- Returns
- Tensor [shape=(1,)]
The masked L1 loss.
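Examples
A minimal usage sketch (illustrative only; here the mask marks every frame as valid):
>>> import paddle
>>> from parakeet.modules.losses import masked_l1_loss
>>> prediction = paddle.randn([4, 20, 80])  # e.g. predicted mel frames
>>> target = paddle.randn([4, 20, 80])
>>> mask = paddle.ones([4, 20, 1])          # 1 for valid positions, 0 for padding
>>> loss = masked_l1_loss(prediction, target, mask)  # shape (1,)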
- parakeet.modules.losses.masked_softmax_with_cross_entropy(logits, label, mask, axis=-1)[source]
Compute masked softmax with cross entropy loss.
- Parameters
- logits : Tensor
The logits. The axis-th axis is the class dimension.
- label : Tensor [dtype: int]
The label. The size of the axis-th axis should be 1.
- mask : Tensor
The mask. The shape should be broadcastable to label.
- axis : int, optional
The index of the class dimension in the shape of logits, by default -1.
- Returns
- Tensor [shape=(1,)]
The masked softmax with cross entropy loss.
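Examples
A minimal usage sketch (illustrative only; the class dimension is the last axis, and the label's class axis has size 1 as required above):
>>> import paddle
>>> from parakeet.modules.losses import masked_softmax_with_cross_entropy
>>> logits = paddle.randn([4, 20, 10])         # 10 classes on the last axis
>>> label = paddle.randint(0, 10, [4, 20, 1])  # int labels; the class axis has size 1
>>> mask = paddle.ones([4, 20, 1])             # broadcastable to label
>>> loss = masked_softmax_with_cross_entropy(logits, label, mask)  # shape (1,)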
- parakeet.modules.losses.weighted_mean(input, weight)[source]
Weighted mean. It can also be used as masked mean.
- Parameters
- input : Tensor
The input tensor.
- weight : Tensor
The weight tensor, with a shape broadcastable to that of the input.
- Returns
- Tensor [shape=(1,)]
The weighted mean tensor, with the same dtype as the input.
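Examples
A sketch of using it as a masked mean (illustrative only):
>>> import paddle
>>> from parakeet.modules.losses import weighted_mean
>>> per_position_loss = paddle.randn([4, 20])
>>> mask = paddle.ones([4, 20])  # a 0/1 mask turns the weighted mean into a masked mean
>>> mean_loss = weighted_mean(per_position_loss, mask)  # shape (1,)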
parakeet.modules.masking module
- parakeet.modules.masking.combine_mask(mask1, mask2)[source]
Combine two masks with multiplication or logical and.
- Parameters
- mask1 : Tensor
The first mask.
- mask2 : Tensor
The second mask, with a shape broadcastable to that of mask1.
- Returns
- Tensor
The combined mask.
Notes
It is mainly used to combine the padding mask and the no-future mask for the transformer decoder.
The padding mask is used to mask padding positions of the decoder inputs, and the no-future mask is used to prevent the decoder from seeing future information.
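Examples
A sketch of the decoder use case described above (illustrative only; id_mask and future_mask are documented later in this module):
>>> import paddle
>>> from parakeet.modules.masking import combine_mask, future_mask, id_mask
>>> ids = paddle.to_tensor([[5, 7, 2, 0, 0]])  # 0 is the padding id
>>> padding_mask = id_mask(ids)                # shape (1, 5)
>>> no_future = future_mask(5)                 # shape (5, 5), lower triangular
>>> decoder_mask = combine_mask(padding_mask.unsqueeze(1), no_future)  # (1, 5, 5)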
- parakeet.modules.masking.feature_mask(input, axis, dtype='bool')[source]
Compute a mask from input features.
For input features represented as batched feature vectors, those vectors which are all zeros are considered padding vectors.
- Parameters
- input : Tensor [dtype: float]
The input tensor, which represents the features.
- axis : int
The index of the feature dimension in input. Other dimensions are considered spatial dimensions.
- dtype : str, optional
Data type of the generated mask, by default "bool".
- Returns
- Tensor
The generated mask with the spatial shape as mentioned above. It has one less dimension than input does.
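Examples
A minimal sketch (illustrative only; padding frames are all-zero feature vectors, as described above):
>>> import paddle
>>> from parakeet.modules.masking import feature_mask
>>> valid = paddle.randn([2, 8, 80])
>>> pad = paddle.zeros([2, 2, 80])
>>> feats = paddle.concat([valid, pad], axis=1)  # the last two frames are padding
>>> mask = feature_mask(feats, axis=-1)          # spatial shape (2, 10), one less dim than feats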
- parakeet.modules.masking.future_mask(time_steps, dtype='bool')[source]
Generate a lower triangular mask.
It is used in the transformer decoder to prevent the decoder from seeing future information.
- Parameters
- time_steps : int
Decoder time steps.
- dtype : str, optional
The data type of the generated mask, by default "bool".
- Returns
- Tensor
The generated mask.
- parakeet.modules.masking.id_mask(input, padding_index=0, dtype='bool')[source]
Generate a mask from input ids.
Positions where the value equals padding_index correspond to 0 or False; otherwise, 1 or True.
- Parameters
- input : Tensor [dtype: int]
The input tensor. It represents the ids.
- padding_index : int, optional
The id which represents padding, by default 0.
- dtype : str, optional
Data type of the returned mask, by default "bool".
- Returns
- Tensor
The generated mask. It has the same shape as input does.
parakeet.modules.positional_encoding module
parakeet.modules.transformer module
- class parakeet.modules.transformer.PositionwiseFFN(input_size: int, hidden_size: int, dropout=0.0)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
A faithful implementation of the Position-wise Feed-Forward Network in Attention is All You Need. It is basically a 2-layer MLP, with a ReLU activation and dropout in between.
- Parameters
- input_size : int
The feature size of the input. It is also the feature size of the output.
- hidden_size : int
The hidden size.
- dropout : float
The probability of the dropout applied to the output of the first layer, by default 0.
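Examples
A minimal usage sketch (illustrative only; the output feature size equals the input feature size, per the parameter descriptions above):
>>> import paddle
>>> from parakeet.modules.transformer import PositionwiseFFN
>>> ffn = PositionwiseFFN(input_size=256, hidden_size=1024, dropout=0.1)
>>> x = paddle.randn([4, 50, 256])
>>> y = ffn(x)  # (4, 50, 256)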
- class parakeet.modules.transformer.TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout=0.0)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
A faithful implementation of the Transformer decoder layer in Attention is All You Need.
- Parameters
- d_model : int
The feature size of the input. It is also the feature size of the output.
- n_heads : int
The number of heads of the attentions (MultiheadAttention layers).
- d_ffn : int
The hidden size of the position-wise feed forward network (a PositionwiseFFN layer).
- dropout : float, optional
The probability of the dropout in MultiheadAttention and PositionwiseFFN, by default 0.
Notes
It uses the PostLN (post layer norm) scheme.
- forward(q, k, v, encoder_mask, decoder_mask)[source]
Forward pass of TransformerDecoderLayer.
- Parameters
- q : Tensor [shape=(batch_size, time_steps_q, d_model)]
The decoder input.
- k : Tensor [shape=(batch_size, time_steps_k, d_model)]
The keys.
- v : Tensor [shape=(batch_size, time_steps_k, d_model)]
The values.
- encoder_mask : Tensor
The encoder padding mask; the shape is (batch_size, time_steps_k, time_steps_k) or a broadcastable shape.
- decoder_mask : Tensor
The decoder mask; the shape is (batch_size, time_steps_q, time_steps_k) or a broadcastable shape.
- Returns
- q : Tensor [shape=(batch_size, time_steps_q, d_model)]
The decoder output.
- self_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_q)]
Decoder self attention.
- cross_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_k)]
Decoder-encoder cross attention.
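Examples
A minimal usage sketch (illustrative only; the masks use broadcastable shapes, with a lower triangular no-future mask for the decoder self attention):
>>> import paddle
>>> from parakeet.modules.transformer import TransformerDecoderLayer
>>> layer = TransformerDecoderLayer(d_model=64, n_heads=4, d_ffn=256)
>>> q = paddle.randn([2, 10, 64])  # decoder input
>>> k = paddle.randn([2, 15, 64])  # encoder output, used as keys
>>> v = paddle.randn([2, 15, 64])  # encoder output, used as values
>>> encoder_mask = paddle.ones([2, 1, 15])                # no encoder padding
>>> decoder_mask = paddle.tril(paddle.ones([2, 10, 10]))  # no-future mask
>>> out, self_attn, cross_attn = layer(q, k, v, encoder_mask, decoder_mask)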
- class parakeet.modules.transformer.TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout=0.0)[source]
Bases:
paddle.fluid.dygraph.layers.Layer
A faithful implementation of the Transformer encoder layer in Attention is All You Need.
- Parameters
- d_model : int
The feature size of the input. It is also the feature size of the output.
- n_heads : int
The number of heads of the self attention (a MultiheadAttention layer).
- d_ffn : int
The hidden size of the position-wise feed forward network (a PositionwiseFFN layer).
- dropout : float, optional
The probability of the dropout in MultiheadAttention and PositionwiseFFN, by default 0.
Notes
It uses the PostLN (post layer norm) scheme.
- forward(x, mask)[source]
Forward pass of TransformerEncoderLayer.
- Parameters
- x : Tensor [shape=(batch_size, time_steps, d_model)]
The input.
- mask : Tensor
The padding mask. The shape is (batch_size, time_steps, time_steps) or a broadcastable shape.
- Returns
- x : Tensor [shape=(batch_size, time_steps, d_model)]
The encoded output.
- attn_weights : Tensor [shape=(batch_size, n_heads, time_steps, time_steps)]
The attention weights of the self attention.
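Examples
A minimal usage sketch (illustrative only; a mask of ones marks every position as valid):
>>> import paddle
>>> from parakeet.modules.transformer import TransformerEncoderLayer
>>> layer = TransformerEncoderLayer(d_model=64, n_heads=4, d_ffn=256)
>>> x = paddle.randn([2, 20, 64])
>>> mask = paddle.ones([2, 1, 20])  # broadcastable padding mask; zeros mark padding
>>> out, attn_weights = layer(x, mask)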