#!/usr/bin/env python
#
# pyLOM - Python Low Order Modeling.
#
# SHRED architecture for NN Module
#
# Williams, J. P., Zahn, O., & Kutz, J. N. (2023). Sensing with shallow recurrent decoder networks. arXiv preprint arXiv:2301.12011.
#
# Last rev: 11/03/2025
import torch
import numpy as np
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from ...utils.cr import cr
from .encoders_decoders import ShallowDecoder
from ..utils import Dataset
class SHRED(nn.Module):
r'''
    Shallow recurrent decoder (SHRED) architecture. For more information on the theoretical background of the architecture, see the following reference:
    Williams, J. P., Zahn, O., & Kutz, J. N. (2023). Sensing with shallow recurrent decoder networks. arXiv preprint arXiv:2301.12011.
    The model is based on the PyTorch library `torch.nn` (detailed documentation can be found at https://pytorch.org/docs/stable/nn.html).
    In this implementation we assume that the outputs are always the POD coefficients of the full dataset.
Args:
output_size (int): Number of POD modes.
device (torch.device): Device to use.
        total_sensors (int): Total number of sensors used to assemble the different configurations.
hidden_size (int, optional): Dimension of the LSTM hidden layers (default: ``64``).
hidden_layers (int, optional): Number of LSTM hidden layers (default: ``2``).
decoder_sizes (list, optional): Integer list of the decoder layer sizes (default: ``[350, 400]``).
input_size (int, optional): Number of sensor signals used as input (default: ``3``).
        dropout (float, optional): Dropout probability for the decoder (default: ``0.1``).
nconfigs (int, optional): Number of configurations to train SHRED on (default: ``1``).
compile (bool, optional): Flag to compile the model (default: ``False``).
        seed (int, optional): Seed for reproducibility; ``-1`` falls back to seed ``0`` (default: ``-1``).
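
    Example:
        A minimal construction sketch, assuming ``SHRED`` is exported from ``pyLOM.NN`` (the sizes below are illustrative, not tied to any dataset)::

            import torch
            from pyLOM.NN import SHRED

            model = SHRED(output_size=20, device=torch.device('cpu'), total_sensors=50)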
'''
def __init__(
self,
output_size:int,
device:torch.device,
total_sensors:int,
hidden_size:int=64,
hidden_layers:int=2,
decoder_sizes:list=[350, 400],
input_size:int=3,
        dropout:float=0.1,
nconfigs:int=1,
compile:bool=False,
seed:int=-1):
        super(SHRED, self).__init__()
        np.random.seed(0 if seed == -1 else seed)
if compile:
self.lstm = torch.compile(nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=hidden_layers, batch_first=True), mode="max-autotune")
self.decoder = torch.compile(ShallowDecoder(output_size, hidden_size, decoder_sizes, dropout), mode="max-autotune")
else:
self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=hidden_layers, batch_first=True)
self.decoder = ShallowDecoder(output_size, hidden_size, decoder_sizes, dropout)
self.sensxconfig = input_size
self.nconfigs = nconfigs
self.hidden_layers = hidden_layers
self.hidden_size = hidden_size
self.configs = np.zeros((self.nconfigs, self.sensxconfig), dtype=int)
for kk in range(self.nconfigs):
self.configs[kk,:] = np.random.choice(total_sensors, size=self.sensxconfig, replace=False)
self.device = device
self.to(device)
def forward(self, x:torch.Tensor):
r'''
Do a forward evaluation of the data.
Args:
x (torch.Tensor): input data to the neural network.
Returns:
(torch.Tensor): Prediction of the neural network.
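
        Example:
            A shape-only sketch (``16`` windows and ``50`` lags are illustrative; the last dimension must match ``input_size``)::

                x = torch.randn(16, 50, 3, device=model.device)  # (batch, time lags, sensors)
                coeffs = model(x)                                 # (batch, output_size)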
'''
        _, (h_n, _) = self.lstm(x)                   # h_n: final hidden state, shape (hidden_layers, batch, hidden_size)
        output = h_n[-1].view(-1, self.hidden_size)  # hidden state of the last LSTM layer
        return self.decoder(output)
def freeze(self):
r'''
Freeze the model parameters to set it on inference mode.
'''
self.eval()
for param in self.parameters():
param.requires_grad = False
def unfreeze(self):
r'''
Unfreeze the model parameters to set it on training mode.
'''
self.train()
for param in self.parameters():
param.requires_grad = True
def _loss_func(self, x:torch.Tensor, recon_x:torch.Tensor, mod_scale:torch.Tensor, reduction:str):
r'''
Model loss function.
Args:
x (torch.Tensor): correct output.
recon_x (torch.Tensor): neural network output.
mod_scale (torch.Tensor): scaling of each POD coefficient according to its energy.
reduction (str): type of reduction applied when doing the MSE.
Returns:
            (torch.Tensor): Loss value.
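
        With :math:`s` the per-mode scaling, the value computed below is

        .. math::

            \mathcal{L} = \operatorname{MSE}\left(s \odot x,\; s \odot \hat{x}\right)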
'''
return F.mse_loss(x*mod_scale, recon_x*mod_scale, reduction=reduction)
def _mre(self, x:torch.Tensor, recon_x:torch.Tensor, mod_scale:torch.Tensor):
r'''
Mean relative error between the original and the SHRED reconstruction.
Args:
x (torch.Tensor): correct output.
recon_x (torch.Tensor): neural network output.
mod_scale (torch.Tensor): scaling of each POD coefficient according to its energy.
Returns:
            (torch.Tensor): Mean relative error.
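
        Explicitly, with :math:`s_j` the scaling of mode :math:`j` and norms taken over the time dimension,

        .. math::

            \mathrm{MRE} = \frac{1}{n_{\mathrm{modes}}} \sum_{j} s_j \frac{\lVert x_j - \hat{x}_j \rVert_2}{\lVert x_j \rVert_2}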
'''
        diff = (x-recon_x)*(x-recon_x)
        num  = torch.sqrt(torch.sum(diff, dim=0))  # l2 norm of the error of each mode
        den  = torch.sqrt(torch.sum(x*x, dim=0))   # l2 norm of the reference of each mode
        return torch.sum(num/den*mod_scale/len(mod_scale))
@cr('SHRED.fit')
def fit(self, train_dataset: Dataset, valid_dataset: Dataset, batch_size:int=64, epochs:int=4000, optim:torch.optim.Optimizer=torch.optim.Adam, lr:float=1e-3, reduction:str='mean', verbose:bool=False, patience:int=5, mod_scale:torch.Tensor=None):
r'''
Fit of the SHRED model.
Args:
train_dataset (torch.utils.data.Dataset): training dataset.
valid_dataset (torch.utils.data.Dataset): validation dataset.
batch_size (int, optional): length of each training batch (default: ``64``).
epochs (int, optional): number of epochs to extend the training (default: ``4000``).
optim (torch.optim, optional): optimizer used (default: ``torch.optim.Adam``).
lr (float, optional): learning rate (default: ``0.001``).
            reduction (str, optional): type of reduction applied to the MSE loss (default: ``'mean'``).
            verbose (bool, optional): if ``True``, print the training and validation error at each epoch (default: ``False``).
            patience (int, optional): number of epochs without improvement of the validation loss before stopping the training (default: ``5``).
            mod_scale (torch.Tensor, optional): scaling of each POD coefficient according to its energy; if ``None``, uniform scaling is used (default: ``None``).
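
        Example:
            A minimal training sketch, assuming ``train_set`` and ``valid_set`` are ``pyLOM.NN.Dataset`` instances prepared elsewhere::

                model.fit(train_set, valid_set, batch_size=64, epochs=200, verbose=True)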
'''
        # Move the datasets to the training device and reorder the inputs to the batch-first layout expected by the LSTM
        train_dataset.variables_in = train_dataset.variables_in.permute(1,2,0).to(self.device)
valid_dataset.variables_in = valid_dataset.variables_in.permute(1,2,0).to(self.device)
train_dataset.variables_out = train_dataset.variables_out.to(self.device)
valid_dataset.variables_out = valid_dataset.variables_out.to(self.device)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
optimizer = optim(self.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=lr*1e-4)
valid_error_list = []
patience_counter = 0
        best_params = deepcopy(self.state_dict())
        mod_scale = torch.ones((train_dataset.variables_out.shape[1],), dtype=torch.float32, device=self.device) if mod_scale is None else mod_scale.to(self.device)
        for epoch in range(1, epochs + 1):
            self.train()
            for k, data in enumerate(train_loader):
                outputs = self(data[0])
                optimizer.zero_grad()
                loss = self._loss_func(data[1], outputs, mod_scale, reduction)
                loss.backward()
                optimizer.step()
            scheduler.step()
self.eval()
with torch.no_grad():
train_error = self._mre(train_dataset.variables_out, self(train_dataset.variables_in), mod_scale)
valid_error = self._mre(valid_dataset.variables_out, self(valid_dataset.variables_in), mod_scale)
valid_error_list.append(valid_error)
            if verbose:
                print("Epoch %i : Training error = %.5e Validation error = %.5e \r" % (epoch, train_error, valid_error), flush=True)
            if valid_error == min(valid_error_list):
                patience_counter = 0
                best_params = deepcopy(self.state_dict())  # deep copy so that later parameter updates do not overwrite the best weights
else:
patience_counter += 1
if patience_counter == patience:
break
self.load_state_dict(best_params)
train_error = self._mre(train_dataset.variables_out, self(train_dataset.variables_in), mod_scale)
valid_error = self._mre(valid_dataset.variables_out, self(valid_dataset.variables_in), mod_scale)
print("Training done: Training loss = %.2f Validation loss = %.2f \r" % (train_error*100, valid_error*100), flush=True)
    def save(self, path:str, scaler_path:str, podscale_path:str, sensors:np.ndarray):
r'''
Save a SHRED configuration to a .pth file.
Args:
path (str): where the model will be saved.
scaler_path (str): path to the scaler used to scale the sensor data.
podscale_path (str): path to the scaler used for the POD coefficients.
            sensors (np.ndarray): IDs of the sensors used for the current SHRED configuration.
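
        Example:
            A usage sketch saving the first sensor configuration (the file names are illustrative)::

                model.save('shred_config0', 'sensor_scaler.json', 'pod_scaler.json', model.configs[0])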
'''
torch.save({
'model_state_dict': self.state_dict(),
'scaler_path' : scaler_path,
'podscale_path' : podscale_path,
'sensors' : sensors,}, "%s.pth" % path)