KnowIt.default_archs.TFT
This module implements a simplified Temporal Fusion Transformer (TFT) architecture.
Unlike the original TFT, this version omits the encoding of exogenous inputs (e.g. static covariates or known future inputs). It is tailored to regression, classification, and variable-length regression tasks, and relies on temporal variable selection, gating mechanisms, LSTM-based encoding, and Interpretable Multi-Head Attention.
The following diagram depicts the overall architecture:
X → [EmbeddingLayer] → [VariableSelectionNetwork] → [LSTM]* → [GateAddNorm] → [InterpretableMultiHeadAttention]* → [GateAddNorm] → [GatedResidualNetwork] → [Dense] → Y
“*” marks a module that is bypassed by a skip connection: the module's input is also routed around it to the GateAddNorm that follows.
See: [1] “Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting”
Lim et al., 2019 — https://arxiv.org/abs/1912.09363
Some inspiration for this implementation has been taken from pytorch-forecasting’s TemporalFusionTransformer <https://pytorch-forecasting.readthedocs.io/en/latest/models.html>.
Note: The LSTM stage of this TFT makes use of the default LSTMv2 architecture within KnowIt. All stateful processing is handled by this underlying LSTM module.
EmbeddingLayer
Performs the initial embedding of input features into a representation space for further processing. Linear embeddings are done in one of two modes:
independent (default): input components are independently embedded,
mixed: the embedding of each input component depends on all input components.
Gate
→ [Dropout] → [Linear] → [GLU] →
The main gating mechanism in the TFT. It consists of a linear layer followed by a Gated Linear Unit (GLU, https://docs.pytorch.org/docs/stable/generated/torch.nn.GLU.html) activation, with optional dropout before the linear layer.
This module is used as a basic building block by other modules in the architecture. It allows the model to “select” which submodules are useful for the current task.
GateAddNorm
x → [Gate] → (+) → [LayerNorm] →
residual → → → ↑
Applies the gating mechanism followed by layer normalization. It also accepts a residual input from earlier in the network, which is added to the output of the Gate before normalization.
This module is used both as a building block and as part of the main architectural flow. Specifically, it follows the LSTM encoder and the InterpretableMultiHeadAttention block, providing gated skip connections around these modules so that the model can down-weight them and thereby adapt the effective complexity of the overall architecture.
GatedResidualNetwork
x → → [Linear] → [ELU] → [Linear] → [GateAddNorm] →
↓ → → → → → → → → → → → → → → → ↑
A gated feedforward block that:
enables non-linear transformations with gating control, and
uses a residual connection to improve gradient flow and allow feature reuse.
This module is used as both a building block and as part of the main architectural flow. Specifically, it is used towards the end of the architecture as a final feedforward stage.
VariableSelectionNetwork
x → → [GatedResidualNetwork] → → → → → → → → [Weighted sum] →
↓ → [GatedResidualNetwork] → [Softmax] → → → → → ↑
This module performs feature selection before the input is passed to the LSTM encoder. Specifically, it returns a weighted sum of the input components, each first processed by a GatedResidualNetwork, where the weights are produced by a separate, dedicated GatedResidualNetwork followed by a softmax.
ScaledDotProductAttention
Performs basic scaled dot-product self-attention, with optional attention dropout and causal masking.
InterpretableMultiHeadAttention
Performs interpretable multi-head attention, in which the query and key projections are computed separately for each head while all heads share a single value projection. According to [1], this improves the interpretability of the resulting attention weights.
While the LSTM is meant to model local information (and encode positional information), the InterpretableMultiHeadAttention module is intended to capture long-term dependencies in the data.
Classes
Model | Temporal Fusion Transformer (TFT)-style model.
Gate | A gating mechanism using Gated Linear Unit (GLU) activations with optional dropout.
GateAddNorm | Gates the input (using Gate), implements a residual connection, and applies layer normalization.
GatedResidualNetwork | A Gated Residual Network (GRN) block for gated non-linear processing.
EmbeddingLayer | Embedding layer supporting independent or fully mixed feature embedding.
VariableSelectionNetwork | A Variable Selection Network (VSN) for selecting features after embedding.
ScaledDotProductAttention | A simple scaled dot-product attention mechanism.
InterpretableMultiHeadAttention | An interpretable multi-head attention block to capture long-term dependencies.
Functions
init_mod | Initializes the parameters of the given module using suitable initialization schemes.
get_output_activation | Fetches the output activation function from PyTorch.
Module Contents
- class KnowIt.default_archs.TFT.Model(input_dim, output_dim, task_name, *, embedding_mode='independent', hidden_dim=32, lstm_depth=4, lstm_width=64, lstm_hc_init_method='zeros', lstm_layernorm=True, lstm_bidirectional=False, num_attention_heads=4, dropout=0.2, output_activation=None, lstm_stateful=True)
Bases: torch.nn.Module
Temporal Fusion Transformer (TFT)-style model.
This module implements a streamlined Temporal Fusion Transformer architecture consisting of:
Per-feature embedding via EmbeddingLayer
Feature-wise selection through a VariableSelectionNetwork
Recurrent sequence modeling using a custom LSTM encoder
Gated skip connections using GateAddNorm
Interpretable multi-head self-attention
Decoder refinement via GatedResidualNetwork
Task-dependent output projection layer
The model supports standard regression, variable-length regression ('vl_regression'), and classification tasks.
- Parameters:
input_dim (list[int], shape=[in_chunk_size, in_components]) – The shape of the input data. The “time axis” is along the first dimension. If variable-length data is processed, then in_chunk_size must be 1.
output_dim (list[int], shape=[out_chunk_size, out_components]) – The shape of the output data. The “time axis” is along the first dimension. If variable-length data is processed, then out_chunk_size must be 1.
task_name (str) – The type of task (‘classification’, ‘regression’, or ‘vl_regression’).
embedding_mode (str, default 'independent') – Embedding strategy. Must be either “independent” or “mixed”. The former embeds input components independently; the latter mixes information across components during embedding.
hidden_dim (int, default 32) – Hidden dimensionality used throughout embeddings, attention, and gated residual blocks.
lstm_depth (int, default 4) – Number of stacked LSTM layers in the encoder. See LSTMv2 for details.
lstm_width (int, default 64) – Internal width of the LSTM layers. See LSTMv2 for details.
lstm_hc_init_method (str, default 'zeros') – Initialization strategy for LSTM hidden and cell states. See LSTMv2 for details.
lstm_layernorm (bool, default True) – If True, applies layer normalization inside the LSTM. See LSTMv2 for details.
lstm_bidirectional (bool, default False) – If True, uses a bidirectional LSTM encoder. See LSTMv2 for details.
num_attention_heads (int, default 4) – Number of heads in the interpretable multi-head attention module.
dropout (float or None, default 0.2) – Dropout probability applied throughout the network. If None, dropout is disabled.
output_activation (str or None, default None) – Optional activation applied after the final linear output layer (only used for non-'vl_regression' tasks).
lstm_stateful (bool, default True) – If True, the LSTM maintains internal states across forward passes.
- Variables:
embedder (EmbeddingLayer) – Per-variable embedding module.
vsn (VariableSelectionNetwork) – Learns dynamic feature importance weights.
lstm_encoder (LSTMV2) – Recurrent sequence encoder.
attention (InterpretableMultiHeadAttention) – Self-attention mechanism operating over the temporal dimension.
decoder_grn (GatedResidualNetwork) – Decoder refinement block.
output_layers (torch.nn.Module) – Final task-dependent projection layers.
- forward(x, *internal_states)
Forward pass of the TFT model.
The computation pipeline consists of:
Feature-wise embedding
Variable selection
LSTM encoding
Residual gating
Interpretable multi-head attention
Decoder gated residual refinement
Task-specific output projection
If multiple internal_states are provided, they overwrite the internal states of the stateful LSTM before processing.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, time_steps, n_features).
*internal_states (tuple of torch.Tensor, optional) – Optional hidden and cell states used to overwrite the internal LSTM states before forward propagation. Only used when the encoder is stateful.
- Returns:
Model output tensor.
For 'regression': shape (batch_size, output_dim[0], output_dim[1]).
For 'vl_regression': shape (batch_size, time_steps, model_out_dim).
For 'classification': shape (batch_size, model_out_dim).
- Return type:
torch.Tensor
- Raises:
SystemExit – If an unsupported task_name is provided.
- force_reset()
Wrapper for the underlying LSTM’s corresponding function.
- get_internal_states()
Wrapper for the underlying LSTM’s corresponding function.
- hard_set_states(ist_idx)
Wrapper for the underlying LSTM’s corresponding function.
- update_states(ist_idx, device)
Wrapper for the underlying LSTM’s corresponding function.
- class KnowIt.default_archs.TFT.Gate(n_inputs, n_outputs=None, dropout=None)
Bases: torch.nn.Module
A gating mechanism using Gated Linear Unit (GLU) activations with optional dropout.
This module applies a linear transformation to the input, doubling the output features to produce gates and values, then uses the GLU activation function to gate the output. Dropout is applied before the linear layer if specified.
- Parameters:
n_inputs (int) – Number of input features.
n_outputs (int, optional (default=None)) – Number of output features. If None, the number of outputs equals the number of inputs.
dropout (float, optional (default=None)) – Dropout probability applied before the linear transformation. If None, no dropout is applied.
- Variables:
fc (Linear) – Linear layer producing gates and values, with output dimension n_outputs * 2.
dropout (Dropout or None) – Dropout layer applied before the linear transformation.
- forward(x)
Forward pass of the gating mechanism: applies dropout (if any), a linear transformation, and GLU gating to the input tensor.
- Parameters:
x (Tensor of shape (batch_size, sequence_length, n_inputs)) – Input tensor.
- Returns:
Output tensor after gating, whose last dimension (n_outputs) is half the output size of the linear layer.
- Return type:
Tensor of shape (batch_size, sequence_length, n_outputs)
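The Dropout → Linear → GLU pattern described above can be illustrated with a minimal sketch in plain PyTorch. GateSketch is a hypothetical name, not KnowIt's class, and the code only mirrors the documented behaviour; it is not the actual implementation. torch.nn.functional.glu splits its input in half along the given dimension and returns the first half multiplied by the sigmoid of the second half.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GateSketch(nn.Module):
    """Hypothetical sketch of the documented Gate: Dropout -> Linear -> GLU."""

    def __init__(self, n_inputs, n_outputs=None, dropout=None):
        super().__init__()
        n_outputs = n_inputs if n_outputs is None else n_outputs
        # Twice the output width: one half holds values, the other half gates.
        self.fc = nn.Linear(n_inputs, n_outputs * 2)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        # F.glu(h) == h[..., :n] * sigmoid(h[..., n:]) along the last dimension.
        return F.glu(self.fc(x), dim=-1)


gate = GateSketch(n_inputs=8, n_outputs=16, dropout=0.1)
y = gate(torch.randn(4, 10, 8))
print(y.shape)  # torch.Size([4, 10, 16])
```

Because the linear layer emits 2 * n_outputs features and the GLU halves them, the output width equals n_outputs, which matches the documented return shape.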
- class KnowIt.default_archs.TFT.GateAddNorm(n_inputs, n_outputs=None, n_residuals=None, dropout=None)
Bases: torch.nn.Module
Gates the input (using Gate), implements a residual connection, and applies layer normalization.
This module applies a Gate transformation to the input tensor, adds a residual connection (projected if necessary to match dimensions), and normalizes the result using LayerNorm.
- Parameters:
n_inputs (int) – Number of input features (last dimension of x).
n_outputs (int, optional) – Number of output features. Defaults to n_inputs.
n_residuals (int, optional) – Dimensionality of the residual input. If different from n_outputs, a linear projection is applied to align dimensions. Defaults to n_outputs.
dropout (float, optional) – Dropout rate applied inside the Gate, before its linear layer. If None, no dropout is applied.
- Variables:
gate (Gate) – The Gate transformation module.
norm (LayerNorm) – Layer normalization applied after residual addition.
- forward(x, residual_connect)
Forward pass: applies the Gate transformation, adds the residual connection (projected if needed), and applies layer normalization.
- Parameters:
x (Tensor of shape (batch_size, sequence_length, n_inputs)) – Input tensor to transform.
residual_connect (Tensor of shape (batch_size, sequence_length, n_residuals)) – Residual tensor to add after transformation. Projected if necessary.
- Returns:
The output tensor after gated transformation, residual addition, and layer normalization.
- Return type:
Tensor of shape (batch_size, sequence_length, n_outputs)
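A hedged sketch of GateAddNorm, used here to illustrate a skip connection around a recurrent encoder. GateAddNormSketch is a hypothetical name, torch.nn.LSTM stands in for KnowIt's LSTMv2, and the residual projection is assumed to be a plain nn.Linear created only when n_residuals differs from n_outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GateAddNormSketch(nn.Module):
    """Hypothetical sketch: LayerNorm(Gate(x) + residual), projecting the residual if needed."""

    def __init__(self, n_inputs, n_outputs=None, n_residuals=None, dropout=None):
        super().__init__()
        n_outputs = n_inputs if n_outputs is None else n_outputs
        n_residuals = n_outputs if n_residuals is None else n_residuals
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.fc = nn.Linear(n_inputs, 2 * n_outputs)  # Gate: Linear followed by GLU
        # Assumption: a plain Linear aligns the residual width only when it differs.
        if n_residuals != n_outputs:
            self.res_proj = nn.Linear(n_residuals, n_outputs)
        else:
            self.res_proj = nn.Identity()
        self.norm = nn.LayerNorm(n_outputs)

    def forward(self, x, residual_connect):
        gated = F.glu(self.fc(self.dropout(x)), dim=-1)
        return self.norm(gated + self.res_proj(residual_connect))


# Skip connection around a recurrent encoder (nn.LSTM stands in for LSTMv2).
x = torch.randn(4, 10, 16)  # (batch, time, hidden_dim)
lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
lstm_out, _ = lstm(x)  # (4, 10, 32)
skip = GateAddNormSketch(n_inputs=32, n_outputs=16)
out = skip(lstm_out, residual_connect=x)  # the LSTM input is the residual
print(out.shape)  # torch.Size([4, 10, 16])
```

If the sigmoid half of the GLU learns to output values near zero, out is roughly LayerNorm(x), so the LSTM is effectively bypassed; this is one way to read the claim that GateAddNorm lets the model adapt its effective complexity.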
- class KnowIt.default_archs.TFT.GatedResidualNetwork(n_inputs, hidden_dim, n_outputs, dropout=None)
Bases: torch.nn.Module
A Gated Residual Network (GRN) block for gated non-linear processing.
This module implements a Gated Residual Network (GRN) block that consists of two linear layers with an ELU activation in between, followed by a GateAddNorm module. The block input is passed to the GateAddNorm module as its residual connection.
- Parameters:
n_inputs (int) – Number of input features (last dimension of x).
hidden_dim (int) – Hidden layer dimensionality within the GRN block.
n_outputs (int, optional) – Number of output features. Defaults to n_inputs.
dropout (float, optional (default=None)) – Dropout rate applied in the GateAddNorm block.
- forward(x, residual_connect=None)
Forward pass through the GRN, applying linear transformations, activation, gating, residual connection, and normalization.
- forward(x)
Forward pass through the Gated Residual Network.
- Parameters:
x (Tensor of shape (batch_size, sequence_length, n_inputs)) – Input tensor.
- Returns:
Output tensor after processing through the GRN block with gated residual connection and normalization.
- Return type:
Tensor of shape (batch_size, sequence_length, n_outputs)
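A minimal sketch of the Linear → ELU → Linear → GateAddNorm block, with the GateAddNorm step inlined so the snippet is self-contained. GRNSketch is a hypothetical name. Using hidden_dim as the width of both linear layers, and projecting the residual only when n_inputs differs from n_outputs, are assumptions modelled on pytorch-forecasting's design rather than confirmed details of KnowIt.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GRNSketch(nn.Module):
    """Hypothetical GRN: Linear -> ELU -> Linear -> GateAddNorm(residual = block input)."""

    def __init__(self, n_inputs, hidden_dim, n_outputs, dropout=None):
        super().__init__()
        self.fc1 = nn.Linear(n_inputs, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Inlined GateAddNorm: optional dropout, Linear -> GLU gate, add residual, LayerNorm.
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.gate_fc = nn.Linear(hidden_dim, 2 * n_outputs)
        # Assumption: project the residual only when input and output widths differ.
        if n_inputs != n_outputs:
            self.res_proj = nn.Linear(n_inputs, n_outputs)
        else:
            self.res_proj = nn.Identity()
        self.norm = nn.LayerNorm(n_outputs)

    def forward(self, x):
        h = self.fc2(F.elu(self.fc1(x)))
        gated = F.glu(self.gate_fc(self.dropout(h)), dim=-1)
        return self.norm(gated + self.res_proj(x))


grn = GRNSketch(n_inputs=16, hidden_dim=32, n_outputs=16, dropout=0.2)
print(grn(torch.randn(4, 10, 16)).shape)  # torch.Size([4, 10, 16])
```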
- class KnowIt.default_archs.TFT.EmbeddingLayer(n_inputs, hidden_dim, mode='independent')
Bases: torch.nn.Module
Embedding layer supporting independent or fully mixed feature embedding.
This module embeds a multivariate time series input of shape (batch_size, time_steps, input_components) into a structured representation of shape (batch_size, time_steps, input_components, hidden_dim).
Two embedding strategies are supported:
- “independent”
Each input component is embedded separately using its own Linear(1 → hidden_dim) layer. No cross-feature interaction occurs during embedding.
- “mixed”
All input components at each time step are jointly projected using a single Linear(n_inputs → n_inputs * hidden_dim) layer. The output is then reshaped to recover a per-feature embedding structure.
In this mode, each feature embedding depends on the full input feature vector at that time step.
- Parameters:
n_inputs (int) – Number of input components (features) in the last dimension of the input tensor.
hidden_dim (int) – Dimensionality of the embedding space for each input component.
mode (str, default "independent") – Embedding strategy. Must be either “independent” or “mixed”.
- Variables:
embedders (nn.ModuleList) – Present when mode=“independent”. Contains n_inputs separate Linear(1, hidden_dim) layers.
embedder (nn.Linear) – Present when mode=“mixed”. A single Linear layer mapping n_inputs → n_inputs * hidden_dim.
Notes
- Input tensor shape:
(batch_size, time_steps, input_components)
- Output tensor shape (both modes):
(batch_size, time_steps, input_components, hidden_dim)
In “mixed” mode, the parameter count scales approximately as:
n_inputs × (n_inputs × hidden_dim)
which grows quadratically with the number of input components.
- forward(x)
Apply the embedding transformation.
- Parameters:
x (Tensor) – Input tensor of shape (batch_size, time_steps, input_components).
- Returns:
Embedded tensor of shape (batch_size, time_steps, input_components, hidden_dim).
In “independent” mode, each feature is embedded separately.
In “mixed” mode, each feature embedding depends on the entire feature vector at the corresponding time step.
- Return type:
Tensor
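The two embedding modes can be sketched as follows. EmbeddingSketch and n_params are hypothetical helpers that mirror the documented layer shapes (one Linear(1 → hidden_dim) per component versus a single Linear(n_inputs → n_inputs * hidden_dim)); they are not KnowIt's code. The snippet also counts parameters to make the quadratic growth of the mixed mode concrete.

```python
import torch
import torch.nn as nn


class EmbeddingSketch(nn.Module):
    """Hypothetical sketch of the 'independent' and 'mixed' embedding modes."""

    def __init__(self, n_inputs, hidden_dim, mode="independent"):
        super().__init__()
        if mode not in ("independent", "mixed"):
            raise ValueError(f"unknown mode: {mode}")
        self.mode, self.n_inputs, self.hidden_dim = mode, n_inputs, hidden_dim
        if mode == "independent":
            # One Linear(1 -> hidden_dim) per input component.
            self.embedders = nn.ModuleList([nn.Linear(1, hidden_dim) for _ in range(n_inputs)])
        else:
            # One joint Linear(n_inputs -> n_inputs * hidden_dim).
            self.embedder = nn.Linear(n_inputs, n_inputs * hidden_dim)

    def forward(self, x):  # x: (B, T, n_inputs)
        if self.mode == "independent":
            parts = [emb(x[..., i : i + 1]) for i, emb in enumerate(self.embedders)]
            return torch.stack(parts, dim=-2)  # (B, T, n_inputs, hidden_dim)
        b, t, _ = x.shape
        # Every output slot depends on the whole feature vector at that time step.
        return self.embedder(x).reshape(b, t, self.n_inputs, self.hidden_dim)


def n_params(module):
    return sum(p.numel() for p in module.parameters())


x = torch.randn(4, 10, 5)
for mode in ("independent", "mixed"):
    emb = EmbeddingSketch(n_inputs=5, hidden_dim=32, mode=mode)
    print(mode, tuple(emb(x).shape), n_params(emb))
```

With n_inputs=5 and hidden_dim=32, the independent mode uses 5 × (32 + 32) = 320 parameters, while the mixed mode uses 5 × 160 + 160 = 960, of which 800 = 5 × 5 × 32 are weights.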
- class KnowIt.default_archs.TFT.VariableSelectionNetwork(n_inputs, hidden_dim, dropout=None)
Bases: torch.nn.Module
A Variable Selection Network (VSN) for selecting features after embedding.
Implements a component-wise variable selection mechanism. Each input component is processed by a component-specific nonlinear transformation. A joint gating network computes importance weights across features, which are used to produce a weighted aggregation of feature representations.
- Parameters:
n_inputs (int) – Number of input components.
hidden_dim (int) – Dimensionality of the shared hidden representation space.
dropout (float, optional) – Dropout rate applied within internal GatedResidualNetwork modules.
Notes
Cross-feature interaction occurs within the variable selection network.
Selection weights are normalized using a softmax over the feature dimension.
The output is a convex combination of processed feature representations.
- forward(x)
Forward pass of the Variable Selection Network.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, T, F, H), where B is the batch size, T is the sequence length, F is the number of input features, and H is the hidden dimensionality.
- Returns:
output (Tensor) – Aggregated feature representation of shape (B, T, hidden_dim).
variable_selection (Tensor) – Tensor of shape (B, T, n_inputs) containing the soft attention weights assigned to each input component. Meant for interpretability.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Notes
The forward computation consists of:
Independent per-feature projection into the shared representation space.
Computation of context-dependent variable selection weights.
Feature-wise nonlinear processing.
Softmax-weighted aggregation across the feature dimension.
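A hedged sketch of these steps: one GRN per feature processes that feature's embedding, a joint GRN over the flattened embeddings produces one logit per feature, and a softmax over the feature axis yields the weights of the convex combination. TinyGRN and VSNSketch are hypothetical names; TinyGRN omits dropout for brevity, and feeding the flattened (F × H) embedding to the selection GRN is an assumption borrowed from pytorch-forecasting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyGRN(nn.Module):
    """Compact GRN stand-in (no dropout): Linear -> ELU -> Linear -> GLU + residual -> LayerNorm."""

    def __init__(self, n_in, hidden, n_out):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(n_in, hidden), nn.Linear(hidden, hidden)
        self.gate = nn.Linear(hidden, 2 * n_out)
        self.skip = nn.Linear(n_in, n_out) if n_in != n_out else nn.Identity()
        self.norm = nn.LayerNorm(n_out)

    def forward(self, x):
        h = self.fc2(F.elu(self.fc1(x)))
        return self.norm(F.glu(self.gate(h), dim=-1) + self.skip(x))


class VSNSketch(nn.Module):
    """Hypothetical VSN: softmax-weighted sum of per-feature GRN outputs."""

    def __init__(self, n_inputs, hidden_dim):
        super().__init__()
        # One GRN per feature, each applied to that feature's own embedding.
        self.feature_grns = nn.ModuleList(
            [TinyGRN(hidden_dim, hidden_dim, hidden_dim) for _ in range(n_inputs)]
        )
        # Joint GRN over all flattened embeddings, producing one logit per feature.
        self.selector = TinyGRN(n_inputs * hidden_dim, hidden_dim, n_inputs)

    def forward(self, x):  # x: (B, T, F, H)
        logits = self.selector(x.flatten(start_dim=-2))  # (B, T, F)
        weights = torch.softmax(logits, dim=-1)  # sums to 1 over F
        processed = torch.stack(
            [grn(x[..., i, :]) for i, grn in enumerate(self.feature_grns)], dim=-2
        )  # (B, T, F, H)
        output = (weights.unsqueeze(-1) * processed).sum(dim=-2)  # (B, T, H)
        return output, weights


vsn = VSNSketch(n_inputs=5, hidden_dim=32)
out, w = vsn(torch.randn(4, 10, 5, 32))
print(out.shape, w.shape)  # torch.Size([4, 10, 32]) torch.Size([4, 10, 5])
```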
- class KnowIt.default_archs.TFT.ScaledDotProductAttention(attention_dropout=None, scale=True, masking=True)
Bases: torch.nn.Module
A simple scaled dot-product attention mechanism.
This module computes attention scores using the dot product between queries and keys, optionally scales the scores, optionally applies causal masking, and applies a softmax function to obtain attention weights. These weights are used to aggregate values. Optionally includes dropout after softmax.
- Parameters:
attention_dropout (float, optional) – Dropout probability to apply after the attention softmax. If None, no dropout is used.
scale (bool, optional) – If True, divides attention scores by the square root of the key dimension. Default is True.
masking (bool, optional) – If True, attention scores are masked to prevent each time step from attending to future time steps. Default is True.
- Variables:
attention_dropout (Dropout or None) – Dropout layer applied to the attention weights, if specified.
scale (bool) – Whether to apply scaling to the attention logits.
masking (bool) – If True, attention scores are masked to prevent each time step from attending to future time steps.
- forward(q, k, v)
Forward pass of the scaled dot-product attention mechanism, returning attention-weighted values and the attention weights.
Computes attention scores as the dot product between queries and keys, optionally scales them, optionally masks for causality, then normalizes with softmax. The resulting attention weights are used to compute a weighted sum over the values.
- Parameters:
q (Tensor, shape [batch_size, sequence_length, embedding_dim]) – Query tensor.
k (Tensor, shape [batch_size, sequence_length, embedding_dim]) – Key tensor.
v (Tensor, shape [batch_size, sequence_length, embedding_dim]) – Value tensor.
- Returns:
out (Tensor, shape [batch_size, sequence_length, embedding_dim]) – Resulting attention-weighted sum of values.
attn (Tensor, shape [batch_size, sequence_length, sequence_length]) – Attention weights after softmax.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
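The computation described above can be sketched as follows, assuming q, k, and v share the shape (batch, time, dim). SDPASketch is a hypothetical name, not KnowIt's class. The causal mask is built with torch.triu so that position t can only attend to positions up to and including t.

```python
import math

import torch
import torch.nn as nn


class SDPASketch(nn.Module):
    """Hypothetical scaled dot-product attention with optional causal mask and dropout."""

    def __init__(self, attention_dropout=None, scale=True, masking=True):
        super().__init__()
        self.attention_dropout = (
            nn.Dropout(attention_dropout) if attention_dropout is not None else None
        )
        self.scale, self.masking = scale, masking

    def forward(self, q, k, v):  # q, k, v: (B, T, D)
        scores = torch.bmm(q, k.transpose(1, 2))  # (B, T, T)
        if self.scale:
            scores = scores / math.sqrt(k.size(-1))
        if self.masking:
            t = scores.size(-1)
            # True strictly above the diagonal marks future positions.
            future = torch.triu(
                torch.ones(t, t, dtype=torch.bool, device=scores.device), diagonal=1
            )
            scores = scores.masked_fill(future, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        if self.attention_dropout is not None:
            attn = self.attention_dropout(attn)
        return torch.bmm(attn, v), attn


q = k = v = torch.randn(2, 5, 8)
out, attn = SDPASketch()(q, k, v)
print(out.shape, attn.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```

Because masked scores are set to negative infinity before the softmax, the first time step can only attend to itself and always receives weight 1.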
- class KnowIt.default_archs.TFT.InterpretableMultiHeadAttention(n_heads, hidden_dim, dropout=None, masking=True)
Bases: torch.nn.Module
An interpretable multi-head attention block to capture long-term dependencies.
Implements an interpretable multi-head attention mechanism where each attention head shares the same value vector but uses separate query and key projections.
This variant aids interpretability by learning queries and keys separately for each head while enforcing a shared value representation. It is useful in architectures where clarity of attention allocation is important (e.g., temporal attention).
- Parameters:
n_heads (int) – Number of attention heads.
hidden_dim (int) – Total feature dimension of the input and output.
dropout (float, optional) – Dropout probability applied after attention and after the output projection. Default is None.
masking (bool, optional) – If True, attention scores are masked to prevent each time step from attending to future time steps. Default is True.
- Variables:
v_layer (nn.Linear) – Shared linear projection for values across all heads.
q_layer (nn.ModuleList) – Per-head linear projections for queries.
k_layer (nn.ModuleList) – Per-head linear projections for keys.
attention (ScaledDotProductAttention) – Core attention mechanism computing attention weights and outputs.
w_h (nn.Linear) – Final projection layer mapping the averaged head output back to the model dimension.
dropout (nn.Dropout) – Dropout layer applied after attention and after the output projection.
- forward(x)
Forward pass of the Interpretable Multi-Head Attention module.
Each attention head uses a separate query and key projection but shares the same value projection across heads. Attention is computed per head, optionally masked for causality, and optionally uses dropout. The outputs of all heads are averaged (not concatenated) to produce the final attention output, which is then projected back to the model dimension and passed through a final dropout.
- Parameters:
x (Tensor, shape [batch_size, sequence_length, hidden_dim]) – Input tensor containing the sequence to attend over. Each element in the sequence is a hidden vector of dimension hidden_dim.
- Returns:
out (Tensor, shape [batch_size, sequence_length, hidden_dim]) – The attention-weighted output sequence after averaging across heads and applying the final linear projection.
attn (Tensor, shape [batch_size, sequence_length, sequence_length, n_heads]) – Attention weights for each head. Each slice attn[:, :, :, i] corresponds to the attention map of head i, after softmax and optional dropout. If n_heads == 1, the head dimension is omitted.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Notes
Causal masking ensures that each position can only attend to previous and current time steps when masking=True.
The value projection is shared across all heads to enforce interpretability, following the Temporal Fusion Transformer design.
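A minimal sketch of interpretable multi-head attention, assuming a per-head width of hidden_dim // n_heads (as in pytorch-forecasting) and always-on causal masking. IMHASketch and causal_attention are hypothetical names, not KnowIt's API. The usage lines check one consequence of sharing V: averaging the head outputs is the same as applying the head-averaged attention matrix to a single value tensor.

```python
import math

import torch
import torch.nn as nn


def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask; returns (out, weights)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
    t = scores.size(-1)
    future = torch.triu(torch.ones(t, t, dtype=torch.bool, device=scores.device), diagonal=1)
    attn = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
    return attn @ v, attn


class IMHASketch(nn.Module):
    """Hypothetical interpretable MHA: per-head Q/K projections, one shared V projection."""

    def __init__(self, n_heads, hidden_dim, dropout=None):
        super().__init__()
        d_head = hidden_dim // n_heads  # assumption: per-head width = hidden_dim // n_heads
        self.q_layer = nn.ModuleList([nn.Linear(hidden_dim, d_head) for _ in range(n_heads)])
        self.k_layer = nn.ModuleList([nn.Linear(hidden_dim, d_head) for _ in range(n_heads)])
        self.v_layer = nn.Linear(hidden_dim, d_head)  # shared by every head
        self.w_h = nn.Linear(d_head, hidden_dim)
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.0)

    def forward(self, x):  # x: (B, T, hidden_dim)
        v = self.v_layer(x)
        outs, attns = [], []
        for q_proj, k_proj in zip(self.q_layer, self.k_layer):
            out, attn = causal_attention(q_proj(x), k_proj(x), v)
            outs.append(self.dropout(out))
            attns.append(attn)
        # Average (not concatenate) the head outputs, then project back to hidden_dim.
        out = self.dropout(self.w_h(torch.stack(outs, dim=-1).mean(dim=-1)))
        attn = torch.stack(attns, dim=-1)  # (B, T, T, n_heads)
        if attn.size(-1) == 1:
            attn = attn.squeeze(-1)
        return out, attn


imha = IMHASketch(n_heads=4, hidden_dim=16).eval()
x = torch.randn(2, 6, 16)
out, attn = imha(x)
print(out.shape, attn.shape)  # torch.Size([2, 6, 16]) torch.Size([2, 6, 6, 4])
# Shared V makes head averaging linear: out == w_h(mean_attention @ V).
print(torch.allclose(out, imha.w_h(attn.mean(-1) @ imha.v_layer(x)), atol=1e-5))  # True
```

This linearity is what allows the head-averaged attention map to be read as a single temporal importance pattern, which is the intuition behind the interpretability argument in [1].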
- KnowIt.default_archs.TFT.init_mod(mod)
Initializes the parameters of the given module using suitable initialization schemes.
This function iterates over the named parameters of the provided module and applies:
- Kaiming uniform initialization for parameters containing ‘weight’ in their name, where applicable;
- standard normal initialization for ‘weight’ parameters where Kaiming initialization is unsuitable;
- zero initialization for parameters containing ‘bias’ in their name.
- Parameters:
mod (nn.Module) – The PyTorch module whose parameters will be initialized.
Notes
This function is used to prepare layers for training by setting their initial weights and biases to suitable values, which can improve convergence rates.
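A sketch of the documented initialization rules, assuming that "unsuitable for Kaiming" means a weight tensor with fewer than two dimensions (torch.nn.init.kaiming_uniform_ raises an error for such tensors because fan-in is undefined). init_mod_sketch is a hypothetical name and may differ from KnowIt's actual init_mod.

```python
import torch
import torch.nn as nn


def init_mod_sketch(mod):
    """Hypothetical re-statement of the documented init rules (not KnowIt's code)."""
    for name, param in mod.named_parameters():
        if "weight" in name:
            if param.dim() >= 2:
                # Kaiming needs a fan-in, which is only defined for >= 2-D tensors.
                nn.init.kaiming_uniform_(param)
            else:
                # Assumed fallback for 1-D weights, e.g. LayerNorm scales.
                nn.init.normal_(param)
        elif "bias" in name:
            nn.init.zeros_(param)


net = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8))
lstm = nn.LSTM(input_size=8, hidden_size=8)
for m in (net, lstm):
    init_mod_sketch(m)
print(all(torch.count_nonzero(p) == 0 for n, p in lstm.named_parameters() if "bias" in n))  # True
```

LSTM parameters are named weight_ih_l0, bias_hh_l0, and so on, so the substring checks on ‘weight’ and ‘bias’ cover recurrent layers as well as linear and normalization layers.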
- KnowIt.default_archs.TFT.get_output_activation(output_activation)
Fetch the output activation function from PyTorch.