#!/usr/bin/env python
#
# pyLOM - Python Low Order Modeling.
#
# Autoencoder architecture for NN Module
#
# Last rev: 09/10/2024
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from functools import reduce
from operator import mul
from .. import DEVICE
from ...utils import cr, pprint
## Wrapper of a plain (deterministic) autoencoder
class Autoencoder(nn.Module):
    r"""
    Deterministic autoencoder built from a user-supplied encoder/decoder pair.

    The encoder maps an input to a latent vector and the decoder maps the
    latent vector back to the input space. Training minimises the MSE between
    the input and its reconstruction.

    Args:
        latent_dim (int): Dimension of the latent space.
        in_shape (tuple): Shape of the input data, excluding the channel dimension.
        input_channels (int): Number of input channels.
        encoder (torch.nn.Module): Encoder model.
        decoder (torch.nn.Module): Decoder model.
        device (str or torch.device): Device to run the model. Default is the package-level ``DEVICE``.
    """

    def __init__(
        self,
        latent_dim: int,
        in_shape: tuple,
        input_channels: int,
        encoder: nn.Module,
        decoder: nn.Module,
        device: torch.device = DEVICE,
    ):
        super().__init__()
        self.lat_dim = latent_dim
        self.in_shape = in_shape
        self.inp_chan = input_channels
        # Number of points per channel (product of the spatial dimensions)
        self.N = reduce(mul, in_shape)
        self.encoder = encoder
        self.decoder = decoder
        self._device = device
        # encoder and decoder are registered submodules, so this moves them too
        self.to(self._device)
        # The summary helper is given the bare device type ('cpu', 'cuda', ...)
        # so that torch.device objects and indexed strings are both accepted.
        summary(self, input_size=(self.inp_chan, *self.in_shape), device=torch.device(device).type)

    def _lossfunc(self, x, recon_x, reduction):
        r"""
        Mean squared error between the input and its reconstruction.

        Both tensors are viewed as rows of ``self.N`` values before comparison.

        Args:
            x (torch.Tensor): Reference data.
            recon_x (torch.Tensor): Reconstructed data.
            reduction (str): Reduction passed to ``torch.nn.functional.mse_loss``.

        Returns:
            torch.Tensor: Loss value.
        """
        return F.mse_loss(recon_x.view(-1, self.N), x.view(-1, self.N), reduction=reduction)

    def forward(self, x):
        r"""
        Encode and decode a batch.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            tuple: ``(recon, z)`` with the reconstruction and the latent representation.
        """
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon, z

    def fit(
        self,
        train_dataset: torch.utils.data.Dataset,
        eval_dataset: torch.utils.data.Dataset = None,
        epochs: int = 100,
        callback=None,
        lr: float = 1e-3,
        BASEDIR: str = "./",
        reduction: str = "mean",
        lr_decay: float = 0.999,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 0,
        pin_memory: bool = True,
    ):
        r"""
        Train the autoencoder model. The logs are stored in the directory specified by BASEDIR with tensorboard format.

        The model state is written to ``{BASEDIR}/model_state.pth`` once training finishes.

        Args:
            train_dataset (torch.utils.data.Dataset): Training dataset.
            eval_dataset (torch.utils.data.Dataset): Evaluation dataset. When ``None`` (default) the
                validation phase is skipped, no validation loss is logged, and the training loss is
                handed to ``callback`` in place of the validation loss.
            epochs (int): Number of epochs to train the model. Default is ``100``.
            callback: Object exposing ``early_stop(va_loss, prev_train_loss, tr_loss)``; training stops
                when it returns a truthy value. Default is ``None``.
            lr (float): Learning rate. Default is ``1e-3``.
            BASEDIR (str): Directory to save the logs and the model. Default is ``"./"``.
            reduction (str): Reduction method for the loss function. Default is ``"mean"``.
            lr_decay (float): Learning rate decay (gamma of the exponential scheduler). Default is ``0.999``.
            batch_size (int): Batch size. Default is ``32``.
            shuffle (bool): Whether to shuffle the dataset or not. Default is ``True``.
            num_workers (int): Number of workers for the Dataloader. Default is ``0``.
            pin_memory (bool): Pin memory for Dataloader. Default is ``True``.
        """
        dataloader_params = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
        }
        train_data = DataLoader(train_dataset, **dataloader_params)
        # Only build the evaluation loader when there is something to evaluate
        eval_data = DataLoader(eval_dataset, **dataloader_params) if eval_dataset is not None else None

        # Initialization
        prev_train_loss = 1e99
        writer = SummaryWriter(BASEDIR)
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

        try:
            # Training loop
            for epoch in range(epochs):
                self.train()
                num_batches = 0
                tr_loss = 0
                for batch0 in train_data:
                    batch = batch0.to(self._device)
                    recon, _ = self(batch)
                    loss = self._lossfunc(batch, recon, reduction)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    tr_loss += loss.item()
                    num_batches += 1
                tr_loss /= num_batches

                # Validation phase
                va_loss = None
                if eval_data is not None:
                    # Evaluation mode so that layers such as dropout or batch
                    # normalisation do not behave as in training
                    self.eval()
                    with torch.no_grad():
                        val_batches = 0
                        va_loss = 0
                        for val_batch0 in eval_data:
                            val_batch = val_batch0.to(self._device)
                            val_recon, _ = self(val_batch)
                            vali_loss = self._lossfunc(val_batch, val_recon, reduction)
                            va_loss += vali_loss.item()
                            val_batches += 1
                        va_loss /= val_batches

                # Logging
                writer.add_scalar("Loss/train", tr_loss, epoch + 1)
                if va_loss is not None:
                    writer.add_scalar("Loss/vali", va_loss, epoch + 1)

                # Early stopping (monitors the training loss when there is no validation set)
                monitored_loss = tr_loss if va_loss is None else va_loss
                if callback and callback.early_stop(monitored_loss, prev_train_loss, tr_loss):
                    print(f'Early Stopper Activated at epoch {epoch}', flush=True)
                    break
                prev_train_loss = tr_loss

                if va_loss is None:
                    print(f'Epoch [{epoch+1} / {epochs}] average training loss: {tr_loss:.5e}', flush=True)
                else:
                    print(f'Epoch [{epoch+1} / {epochs}] average training loss: {tr_loss:.5e} | average validation loss: {va_loss:.5e}', flush=True)

                # Learning rate scheduling
                scheduler.step()
        finally:
            # Cleanup: release the tensorboard writer even if training fails
            writer.flush()
            writer.close()

        torch.save(self.state_dict(), f'{BASEDIR}/model_state.pth')

    def reconstruct(self, dataset: torch.utils.data.Dataset):
        r"""
        Reconstruct the dataset using the trained autoencoder model. It prints the energy, mean, and fluctuation of the reconstructed dataset.

        The whole dataset is processed as a single batch.

        Args:
            dataset (torch.utils.data.Dataset): Dataset to reconstruct.

        Returns:
            np.ndarray: Reconstructed dataset with shape ``(input_channels, N, num_samples)``.
        """
        ## Compute reconstruction and its accuracy
        num_samples = len(dataset)
        ek = np.zeros(num_samples)
        mu = np.zeros(num_samples)
        si = np.zeros(num_samples)
        rec = torch.zeros((self.inp_chan, self.N, num_samples), device=self._device)
        loader = torch.utils.data.DataLoader(dataset, batch_size=num_samples, shuffle=False)
        flat_size = self.inp_chan * self.N
        with torch.no_grad():
            ## Energy recovered in reconstruction
            for energy_batch in loader:
                energy_batch = energy_batch.to(self._device)
                x_recon, _ = self(energy_batch)
                for i in range(num_samples):
                    rec[:, :, i] = x_recon[i].view(self.inp_chan, self.N)
                    x = energy_batch[i].view(flat_size)
                    xr = rec[:, :, i].view(flat_size)
                    x_mean, xr_mean = torch.mean(x), torch.mean(xr)
                    x_std, xr_std = torch.std(x), torch.std(xr)
                    # .item() pulls each scalar to the host explicitly
                    ek[i] = (torch.sum((x - xr) ** 2) / torch.sum(x ** 2)).item()
                    mu[i] = (2 * x_mean * xr_mean / (x_mean ** 2 + xr_mean ** 2)).item()
                    si[i] = (2 * x_std * xr_std / (x_std ** 2 + xr_std ** 2)).item()
        energy = (1 - np.mean(ek)) * 100
        print('Recovered energy %.2f' % energy)
        print('Recovered mean %.2f' % (np.mean(mu) * 100))
        print('Recovered fluct %.2f' % (np.mean(si) * 100))
        return rec.cpu().numpy()

    def latent_space(self, dataset: torch.utils.data.Dataset):
        r"""
        Compute the latent space of the elements of a given dataset.

        The whole dataset is encoded as a single batch.

        Args:
            dataset (torch.utils.data.Dataset): Dataset to compute the latent space.

        Returns:
            torch.Tensor: Latent representation of the dataset elements, as produced by the model
            (it is not converted to NumPy nor moved off the model device).
        """
        # Compute latent vectors
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        with torch.no_grad():
            batch = next(iter(loader)).to(self._device)
            _, z = self(batch)
        return z

    def decode(self, z):
        r"""
        Decode the latent space to the original space.

        Only channel 0 of the decoder output is returned, and the output is
        indexed as a 4-D tensor ``(sample, channel, :, :)``.

        Args:
            z (np.ndarray): Element(s) of the latent space, one per row.

        Returns:
            np.ndarray: Decoded fields with shape ``(N, number of samples)``.
        """
        # The decoder lives on self._device, so the latent vectors must too
        zt = torch.as_tensor(z, dtype=torch.float32, device=self._device)
        with torch.no_grad():
            var = self.decoder(zt)
        var = var.cpu()
        varr = np.zeros((self.N, var.shape[0]), dtype=float)
        for it in range(var.shape[0]):
            varaux = var[it, 0, :, :].detach().numpy()
            varr[:, it] = varaux.reshape((self.N,), order='C')
        return varr
## Wrapper of a variational autoencoder
class VariationalAutoencoder(Autoencoder):
    r"""
    Variational autoencoder built from a user-supplied encoder/decoder pair.

    The encoder must return the pair ``(mu, logvar)``; a latent sample is drawn
    with the reparameterisation trick and passed to the decoder.

    Args:
        latent_dim (int): Dimension of the latent space.
        in_shape (tuple): Shape of the input data.
        input_channels (int): Number of input channels.
        encoder (torch.nn.Module): Encoder model.
        decoder (torch.nn.Module): Decoder model.
        device (str or torch.device): Device to run the model. Default is the package-level ``DEVICE``.
    """

    def __init__(self, latent_dim, in_shape, input_channels, encoder, decoder, device=DEVICE):
        super().__init__(latent_dim, in_shape, input_channels, encoder, decoder, device)

    def _reparamatrizate(self, mu, logvar):
        r"""
        Draw a latent sample ``mu + std * epsilon`` with ``epsilon ~ N(0, 1)``.

        Args:
            mu (torch.Tensor): Mean of the latent distribution.
            logvar (torch.Tensor): Logarithm of the variance of the latent distribution.

        Returns:
            torch.Tensor: Sampled latent vectors.
        """
        std = torch.exp(0.5 * logvar)
        # Standard normal noise with the same shape as std
        epsilon = torch.randn_like(std)
        return mu + std * epsilon

    def _kld(self, mu, logvar):
        r"""
        Kullback-Leibler term used in the training loss.

        Args:
            mu (torch.Tensor): Mean of the latent distribution.
            logvar (torch.Tensor): Logarithm of the variance of the latent distribution.

        Returns:
            torch.Tensor: Scalar ``0.5 * sum(1 + logvar - mean(mu)**2 - exp(mean(logvar)))``,
            where the means are taken over dimension 0.
        """
        mum = torch.mean(mu, axis=0)
        logvarm = torch.mean(logvar, axis=0)
        # NOTE(review): the per-sample logvar is combined with batch-averaged
        # mu/logvar here; confirm this mix is intended rather than using
        # logvarm (or per-sample values) throughout.
        return 0.5 * torch.sum(1 + logvar - mum**2 - logvarm.exp())

    def forward(self, x):
        r"""
        Encode, sample and decode a batch.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            tuple: ``(recon, mu, logvar, z)``.
        """
        mu, logvar = self.encoder(x)
        z = self._reparamatrizate(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar, z

    @cr('VAE.fit')
    def fit(
        self,
        train_dataset,
        eval_dataset=None,
        betasch=None,
        epochs=1000,
        callback=None,
        lr=1e-4,
        BASEDIR="./",
        batch_size=32,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    ):
        r"""
        Train the variational autoencoder model. The logs are stored in the directory specified by BASEDIR with tensorboard format.

        The loss of each batch is ``MSE(sum) - beta * KLD``. The model state is
        written to ``{BASEDIR}/model_state`` once training finishes.

        Args:
            train_dataset (torch.utils.data.Dataset): Training dataset.
            eval_dataset (torch.utils.data.Dataset): Evaluation dataset. When ``None`` (default) the
                validation phase is skipped, no validation loss is logged, and the training loss is
                handed to ``callback`` in place of the validation loss.
            betasch: Object exposing ``getBeta(epoch)`` that returns the KLD weight for the epoch.
                When ``None`` (default) beta is ``0``.
            epochs (int): Number of epochs to train the model. Default is ``1000``.
            callback: Object exposing ``early_stop(va_loss, prev_train_loss, tr_loss)``; training stops
                when it returns a truthy value. Default is ``None``.
            lr (float): Learning rate. Default is ``1e-4``.
            BASEDIR (str): Directory to save the logs and the model. Default is ``"./"``.
            batch_size (int): Batch size. Default is ``32``.
            shuffle (bool): Whether to shuffle the dataset or not. Default is ``True``.
            num_workers (int): Number of workers for the Dataloader. Default is ``0``.
            pin_memory (bool): Pin memory for Dataloader. Default is ``True``.
        """
        dataloader_params = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
        }
        train_data = DataLoader(train_dataset, **dataloader_params)
        # Only build the evaluation loader when there is something to evaluate
        eval_data = DataLoader(eval_dataset, **dataloader_params) if eval_dataset is not None else None

        # autocast needs the bare device type as a string; this also makes the
        # CPU check work for torch.device objects and strings such as 'cuda:0'
        device_type = torch.device(self._device).type
        off_cpu = device_type != "cpu"

        prev_train_loss = 1e99
        writer = SummaryWriter(BASEDIR)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=0, amsgrad=off_cpu, fused=off_cpu)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=lr * 1e-3)
        scaler = GradScaler()

        try:
            for epoch in range(epochs):
                ## Training
                self.train()
                tr_loss = 0
                mse = 0
                kld = 0
                beta = betasch.getBeta(epoch) if betasch is not None else 0
                for batch0 in train_data:
                    batch = batch0.to(self._device)
                    optimizer.zero_grad()
                    with autocast(device_type=device_type):
                        recon, mu, logvar, _ = self(batch)
                        mse_i = self._lossfunc(batch, recon, reduction='sum')
                        kld_i = self._kld(mu, logvar)
                        loss = mse_i - beta * kld_i
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    tr_loss += loss.item()
                    mse += mse_i.item()
                    kld += kld_i.item()
                num_batches = len(train_data)
                tr_loss /= num_batches
                mse /= num_batches
                kld /= num_batches

                ## Validation
                va_loss = None
                if eval_data is not None:
                    self.eval()
                    va_loss = 0
                    with torch.no_grad():
                        for val_batch0 in eval_data:
                            val_batch = val_batch0.to(self._device)
                            with autocast(device_type=device_type):
                                val_recon, val_mu, val_logvar, _ = self(val_batch)
                                mse_i = self._lossfunc(val_batch, val_recon, reduction='sum')
                                kld_i = self._kld(val_mu, val_logvar)
                                vali_loss = mse_i - beta * kld_i
                            va_loss += vali_loss.item()
                    va_loss /= len(eval_data)

                ## Logging
                writer.add_scalar("Loss/train", tr_loss, epoch + 1)
                if va_loss is not None:
                    writer.add_scalar("Loss/vali", va_loss, epoch + 1)
                writer.add_scalar("Loss/mse", mse, epoch + 1)
                writer.add_scalar("Loss/kld", kld, epoch + 1)

                ## Early stopping (monitors the training loss when there is no validation set)
                if callback is not None:
                    monitored_loss = tr_loss if va_loss is None else va_loss
                    if callback.early_stop(monitored_loss, prev_train_loss, tr_loss):
                        pprint(0, 'Early Stopper Activated at epoch %i' % epoch, flush=True)
                        break
                prev_train_loss = tr_loss

                if va_loss is None:
                    pprint(0, 'Epoch [%d / %d] average training loss: %.5e (MSE = %.5e KLD = %.5e)' % (epoch+1, epochs, tr_loss, mse, kld), flush=True)
                else:
                    pprint(0, 'Epoch [%d / %d] average training loss: %.5e (MSE = %.5e KLD = %.5e) | average validation loss: %.5e' % (epoch+1, epochs, tr_loss, mse, kld, va_loss), flush=True)

                # Learning rate scheduling
                scheduler.step()
        finally:
            # Release the tensorboard writer even if training fails
            writer.flush()
            writer.close()

        torch.save(self.state_dict(), '%s/model_state' % BASEDIR)

    @cr('VAE.reconstruct')
    def reconstruct(self, dataset):
        r"""
        Reconstruct the dataset using the trained variational autoencoder model. It prints the energy, mean, and fluctuation of the reconstructed dataset.

        The whole dataset is processed as a single batch.

        Args:
            dataset (torch.utils.data.Dataset): Dataset to reconstruct.

        Returns:
            np.ndarray: Reconstructed dataset with shape ``(input_channels, N, num_samples)``.
        """
        ## Compute reconstruction and its accuracy
        num_samples = len(dataset)
        ek = np.zeros(num_samples)
        mu = np.zeros(num_samples)
        si = np.zeros(num_samples)
        rec = torch.zeros((self.inp_chan, self.N, num_samples), device=self._device)
        loader = torch.utils.data.DataLoader(dataset, batch_size=num_samples, shuffle=False)
        flat_size = self.inp_chan * self.N
        with torch.no_grad():
            ## Energy recovered in reconstruction
            for energy_batch in loader:
                energy_batch = energy_batch.to(self._device)
                x_recon, _, _, _ = self(energy_batch)
                for i in range(num_samples):
                    rec[:, :, i] = x_recon[i].view(self.inp_chan, self.N)
                    x = energy_batch[i].view(flat_size)
                    xr = rec[:, :, i].view(flat_size)
                    x_mean, xr_mean = torch.mean(x), torch.mean(xr)
                    x_std, xr_std = torch.std(x), torch.std(xr)
                    # .item() pulls each scalar to the host explicitly
                    ek[i] = (torch.sum((x - xr) ** 2) / torch.sum(x ** 2)).item()
                    mu[i] = (2 * x_mean * xr_mean / (x_mean ** 2 + xr_mean ** 2)).item()
                    si[i] = (2 * x_std * xr_std / (x_std ** 2 + xr_std ** 2)).item()
        energy = (1 - np.mean(ek)) * 100
        print('Recovered energy %.2f' % energy)
        print('Recovered mean %.2f' % (np.mean(mu) * 100))
        print('Recovered fluct %.2f' % (np.mean(si) * 100))
        return rec.cpu().numpy()

    def correlation(self, dataset):
        r"""
        Compute the correlation between the latent variables of the given dataset.

        As a side effect the sampled latent vectors are written to ``z.npy`` in
        the current working directory.

        Args:
            dataset (torch.utils.data.Dataset): Dataset to compute the correlation.

        Returns:
            tuple: ``(corr, detR)`` where ``corr`` is the correlation matrix between latent
            variables (np.ndarray) and ``detR`` is its determinant multiplied by 100.
        """
        ## Compute correlation between latent variables
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        with torch.no_grad():
            batch = next(iter(loader)).to(self._device)
            _, _, _, z = self(batch)
            z_host = z.cpu()
            np.save('z.npy', z_host)
            # Columns are the latent variables, rows the samples
            corr = np.corrcoef(z_host, rowvar=False)
        detR = np.linalg.det(corr) * 100
        print('Orthogonality between modes %.2f' % (detR))
        return corr, detR

    def modes(self):
        r"""
        Compute the modes of the latent space.

        Each mode is the decoder output for one unit vector of the latent
        space. Only channel 0 of the decoder output is used, and the output is
        indexed as a 4-D tensor ``(mode, channel, :, :)``.

        Returns:
            np.ndarray: Modes flattened in C order from an ``(N, latent_dim)`` array,
            i.e. a vector of length ``N * latent_dim``.
        """
        # Identity matrix: one unit latent vector per mode
        zmode = np.diag(np.ones((self.lat_dim,), dtype=float))
        zmodt = torch.tensor(zmode, dtype=torch.float32).to(self._device)
        modes = self.decoder(zmodt).cpu()
        mymod = np.zeros((self.N, self.lat_dim), dtype=float)
        for imode in range(self.lat_dim):
            modesr = modes[imode, 0, :, :].detach().numpy()
            mymod[:, imode] = modesr.reshape((self.N,), order='C')
        return mymod.reshape((self.N * self.lat_dim,), order='C')

    def latent_space(self, dataset):
        r"""
        Compute the latent space of the elements of a given dataset.

        The whole dataset is encoded as a single batch and the returned values
        are the sampled latent vectors ``z`` (not the means).

        Args:
            dataset (torch.utils.data.Dataset): Dataset to compute the latent space.

        Returns:
            torch.Tensor: Latent representation of the dataset elements, as produced by the model
            (it is not converted to NumPy nor moved off the model device).
        """
        # Compute latent vectors
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        with torch.no_grad():
            batch = next(iter(loader)).to(self._device)
            _, _, _, z = self(batch)
        return z

    def fine_tune(self, train_dataset, shape_, eval_dataset=None, epochs=1000, callback=None, lr=1e-4, BASEDIR='./', **dataloader_params):
        r"""
        Fine tune the decoder only, mapping given latent vectors to target fields.

        Each row of the input arrays holds the latent vector in its first
        ``latent_dim`` columns followed by the flattened target field. The
        decoder state is written to ``{BASEDIR}/decoder_state`` and reloaded.

        Args:
            train_dataset (np.ndarray): Training data, one sample per row.
            shape_: Unused; kept for interface compatibility.
            eval_dataset (np.ndarray): Evaluation data, one sample per row. When ``None`` (default)
                the validation phase is skipped, no validation loss is logged, and the training loss
                is handed to ``callback`` in place of the validation loss.
            epochs (int): Number of epochs. Default is ``1000``.
            callback: Object exposing ``early_stop(va_loss, prev_train_loss, tr_loss)``. Default is ``None``.
            lr (float): Learning rate. Default is ``1e-4``.
            BASEDIR (str): Directory to save the logs and the decoder state. Default is ``'./'``.
            **dataloader_params: Extra keyword arguments forwarded to the DataLoaders.

        Returns:
            int: Always ``0``.
        """
        train_data = DataLoader(torch.from_numpy(train_dataset).to(torch.float32), **dataloader_params)
        eval_data = None
        if eval_dataset is not None:
            eval_data = DataLoader(torch.from_numpy(eval_dataset).to(torch.float32), **dataloader_params)

        # autocast needs the bare device type as a string; this also makes the
        # CPU check work for torch.device objects and strings such as 'cuda:0'
        device_type = torch.device(self._device).type
        off_cpu = device_type != "cpu"

        prev_train_loss = 1e99
        writer = SummaryWriter(BASEDIR)
        decoder_model = self.decoder
        optimizer = torch.optim.Adam(decoder_model.parameters(), lr=lr, weight_decay=0, amsgrad=off_cpu, fused=off_cpu)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=lr * 1e-3)
        scaler = GradScaler()

        try:
            for epoch in range(epochs):
                ## Training
                decoder_model.train()
                tr_loss = 0
                for batch0 in train_data:
                    batch = batch0.to(self._device)
                    optimizer.zero_grad()
                    with autocast(device_type=device_type):
                        in_data = batch[:, :self.lat_dim]
                        recon = decoder_model(in_data)
                        target = torch.reshape(batch[:, self.lat_dim:], recon.shape)
                        loss = self._lossfunc(target, recon, reduction='sum')
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    tr_loss += loss.item()
                tr_loss /= len(train_data)

                ## Validation
                va_loss = None
                if eval_data is not None:
                    decoder_model.eval()
                    va_loss = 0
                    with torch.no_grad():
                        for val_batch0 in eval_data:
                            val_batch = val_batch0.to(self._device)
                            with autocast(device_type=device_type):
                                val_in_data = val_batch[:, :self.lat_dim]
                                val_recon = decoder_model(val_in_data)
                                # The target must come from the validation batch itself
                                val_target = torch.reshape(val_batch[:, self.lat_dim:], val_recon.shape)
                                vali_loss = self._lossfunc(val_target, val_recon, reduction='sum')
                            va_loss += vali_loss.item()
                    va_loss /= len(eval_data)

                ## Logging
                writer.add_scalar("Ft/Loss/train", tr_loss, epoch + 1)
                if va_loss is not None:
                    writer.add_scalar("Ft/Loss/vali", va_loss, epoch + 1)

                ## Early stopping (monitors the training loss when there is no validation set)
                if callback is not None:
                    monitored_loss = tr_loss if va_loss is None else va_loss
                    if callback.early_stop(monitored_loss, prev_train_loss, tr_loss):
                        pprint(0, 'Early Stopper Activated at epoch %i' % epoch, flush=True)
                        break
                prev_train_loss = tr_loss

                if va_loss is None:
                    pprint(0, 'Epoch [%d / %d] average training loss: %.5e' % (epoch+1, epochs, tr_loss), flush=True)
                else:
                    pprint(0, 'Epoch [%d / %d] average training loss: %.5e | average validation loss: %.5e' % (epoch+1, epochs, tr_loss, va_loss), flush=True)

                # Learning rate scheduling
                scheduler.step()
        finally:
            # Release the tensorboard writer even if training fails
            writer.flush()
            writer.close()

        torch.save(decoder_model.state_dict(), '%s/decoder_state' % BASEDIR)
        self.decoder.load_state_dict(torch.load('%s/decoder_state' % BASEDIR, weights_only=True))
        return 0

    def decode(self, z):
        r"""
        Decode a latent space element to the original space.

        Only channel 0 of the decoder output is returned, and the output is
        indexed as a 4-D tensor ``(sample, channel, :, :)``.

        Args:
            z (np.ndarray): Element(s) of the latent space, one per row.

        Returns:
            np.ndarray: Decoded fields with shape ``(N, number of samples)``.
        """
        # The decoder lives on self._device, so the latent vectors must too
        zt = torch.as_tensor(z, dtype=torch.float32, device=self._device)
        with torch.no_grad():
            var = self.decoder(zt)
        var = var.cpu()
        varr = np.zeros((self.N, var.shape[0]), dtype=float)
        for it in range(var.shape[0]):
            varaux = var[it, 0, :, :].detach().numpy()
            varr[:, it] = varaux.reshape((self.N,), order='C')
        return varr