#!/usr/bin/env python
#
# pyLOM - Python Low Order Modeling.
#
# NN interpolation routines.
#
# Last rev: 22/05/2025
import torch
from ..dataset import Dataset as pyLOMDataset
from ..utils.errors import raiseError
from .. import pprint
class Interpolator():
    r"""
    Adjusts fields of a pyLOM dataset through gradient-based optimization.

    Each case (column) of a field is modified so that an objective (by default,
    the squared distance to the original values) is minimized while a penalty
    term drives quantities derived from the field towards reference values.

    Args:
        dataset (pyLOMDataset): Dataset holding the fields to adjust. The
            adjusted field is added to it by :meth:`adjust_field`.
    """
    def __init__(
        self,
        dataset: pyLOMDataset,
    ):
        self.dataset = dataset

    @staticmethod
    def objective_mse(
        field_mod: torch.Tensor,
        field_ref: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Objective function measuring how far the modified field is from the original one.

        Args:
            field_mod (torch.Tensor): Modified field.
            field_ref (torch.Tensor): Original (reference) field.
            **kwargs: Ignored; accepted so that all optimization parameters can be
                forwarded uniformly to every objective function.

        Returns:
            torch.Tensor: Sum (not mean) of the squared differences between
            ``field_mod`` and ``field_ref``.
        """
        return torch.nn.functional.mse_loss(field_mod, field_ref, reduction='sum')

    @staticmethod
    def multitarget_equality_penalty(
        field_mod: torch.Tensor,
        penalty_func: callable,
        target_names: list,
        ref_values: dict,
        penalty_args: dict,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Penalty pushing several quantities computed from the field towards reference values.

        Args:
            field_mod (torch.Tensor): Modified field.
            penalty_func (callable): Called as ``penalty_func(field_mod, **penalty_args)``;
                must return an indexable sequence whose ``i``-th entry is the value of
                target ``target_names[i]``.
            target_names (list): Names of the targets. The first one sets the
                reference magnitude used for weighting; it must not be empty.
            ref_values (dict): Reference tensor for each name in ``target_names``.
            penalty_args (dict): Additional keyword arguments for ``penalty_func``.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: Weighted sum of squared errors
            :math:`\sum_i w_i (t_i - r_i)^2` with
            :math:`w_i = (|r_0| + \epsilon) / (|r_i| + \epsilon)`, so each target's
            error is rescaled to the magnitude of the first target.
        """
        epsilon = 1e-8  # keeps the weights finite for vanishing reference values
        targets = penalty_func(field_mod, **penalty_args)
        # Weights are built from detached references so they act as constants.
        ref0 = ref_values[target_names[0]].detach().abs() + epsilon
        penalty = 0.0
        for i, name in enumerate(target_names):
            ref = ref_values[name]
            weight = ref0 / (ref.detach().abs() + epsilon)
            # Out-of-place accumulation: no in-place ops on autograd tensors.
            penalty = penalty + weight * (targets[i] - ref) ** 2
        return penalty

    @staticmethod
    def get_opt_params_for_case(
        dataset: pyLOMDataset,
        i: int,
        **kwargs,
    ) -> dict:
        r"""
        Gather the optimization parameters for one case of the dataset.

        Args:
            dataset (pyLOMDataset): The dataset containing the fields.
            i (int): Index of the current case.
            **kwargs: Recognised keys:

                - ``opt_param_config`` (dict): maps a dataset name to a tuple
                  ``(source_type, func_arg_name)``. ``source_type`` is
                  ``'get_variable'`` (value taken from
                  ``dataset.get_variable(name)[i]``) or ``'field'`` (the whole
                  ``dataset[name]``, not sliced per case).
                - ``target_names`` (list): argument names treated as targets; their
                  values go to ``ref_values``, all other values to ``penalty_args``.
                - ``penalty_args`` (dict): extra arguments for the penalty function.
                  It is copied, never modified.
                - ``penalty_func`` (callable): returned unchanged.

        Returns:
            dict: Keys ``'ref_values'``, ``'penalty_args'``, ``'penalty_func'`` and
            ``'target_names'``, suitable as ``**kwargs`` for the objective and
            constraint functions.
        """
        mapping = kwargs.get('opt_param_config') or {}
        target_names = kwargs.get('target_names') or []
        # Copy so the caller's dict is not filled with per-case tensors.
        penalty_args = dict(kwargs.get('penalty_args') or {})
        ref_values = {}
        for original_name, (source_type, func_arg_name) in mapping.items():
            if source_type == 'get_variable':
                val = dataset.get_variable(original_name)[i]
            elif source_type == 'field':
                val = dataset[original_name]
            else:
                raiseError(f"Unknown source type '{source_type}' for variable '{original_name}'")
            destination = ref_values if func_arg_name in target_names else penalty_args
            destination[func_arg_name] = torch.tensor(val)
        return {
            'ref_values': ref_values,
            'penalty_args': penalty_args,
            'penalty_func': kwargs.get('penalty_func'),
            'target_names': kwargs.get('target_names'),
        }

    def adjust_field(
        self,
        fieldname: str,
        obj_func: callable = objective_mse,
        get_opt_param_func: callable = get_opt_params_for_case,
        constr_func: callable = multitarget_equality_penalty,
        optimizer_class: torch.optim.Optimizer = torch.optim.Adam,
        schduler_class: torch.optim.lr_scheduler._LRScheduler = torch.optim.lr_scheduler.StepLR,
        opt_config: dict = None,
        disp_progress: tuple = (False, 0),
        **kwargs
    ) -> tuple[pyLOMDataset, list]:
        r"""
        Adjust every case (column) of a dataset field with a gradient-based optimizer.

        The result is stored in the dataset as a new field named
        ``fieldname + 'Adjusted'``; the original field is left untouched.

        Args:
            fieldname (str): Name of the field to be adjusted.
            obj_func (callable, optional): Objective, called as
                ``obj_func(column, original_column, **opt_vars)``
                (default: :meth:`objective_mse`).
            get_opt_param_func (callable, optional): Called as
                ``get_opt_param_func(dataset, i, **kwargs)`` for each case ``i``; its
                dict result is forwarded as ``**opt_vars``
                (default: :meth:`get_opt_params_for_case`).
            constr_func (callable, optional): Penalty, called as
                ``constr_func(column, **opt_vars)``; pass ``None`` to disable it
                (default: :meth:`multitarget_equality_penalty`).
            optimizer_class (torch.optim.Optimizer, optional): Optimizer class to use (default: ``torch.optim.Adam``).
            schduler_class (torch.optim.lr_scheduler._LRScheduler, optional): Learning rate scheduler class,
                built with ``step_size`` and ``gamma`` (default: ``torch.optim.lr_scheduler.StepLR``).
            opt_config (dict, optional): Configuration for the optimizer (default: ``None``).
                - niter (int): Number of iterations (default: ``1000``).
                - lr (float): Learning rate (default: ``1e-2``).
                - lr_step_size (int): Step size for the learning rate scheduler (default: ``1``).
                - lr_gamma (float): Gamma for the learning rate scheduler (default: ``0.999``).
                - penalty_factor (float): Penalty factor for the constraint function (default: ``1e5``).
                - tolerance (float): Loss change below which an epoch counts as stalled (default: ``1e-9``).
                - patience (int): Consecutive stalled epochs before stopping (default: ``10``).
            disp_progress (tuple, optional): ``(show, every)``: whether to print progress and
                the epoch interval between prints; ``every < 1`` is treated as ``1``
                (default: ``(False, 0)``).
            **kwargs: Additional arguments forwarded to ``get_opt_param_func``.

        Returns:
            tuple: ``(dataset, field_losses)`` where ``field_losses[i]`` is the list of
            ``[objective, penalty, total]`` losses recorded at each epoch of case ``i``.
        """
        default_config = {
            'niter': 1000,
            'lr': 1e-2,
            'lr_step_size': 1,
            'lr_gamma': 0.999,
            'penalty_factor': 1e5,
            'tolerance': 1e-9,
            'patience': 10,
        }
        config = {**default_config, **(opt_config or {})}
        # The defaults above are the class' staticmethod objects captured at
        # definition time, which are only callable from Python 3.10 on.
        obj_func = self._as_callable(obj_func)
        get_opt_param_func = self._as_callable(get_opt_param_func)
        constr_func = self._as_callable(constr_func)

        field = self.dataset[fieldname]
        field_mod = field.copy()
        field_losses = []
        for i, col in enumerate(field.T):
            opt_vars = get_opt_param_func(self.dataset, i, **kwargs)
            if disp_progress[0]:
                pprint(0, f"\nCase {i}:")
            new_col, losses = self._optimize_case(
                col, opt_vars, obj_func, constr_func,
                optimizer_class, schduler_class, config, disp_progress,
            )
            field_mod[:, i] = new_col
            field_losses.append(losses)
        ndim = self.dataset.info(fieldname)['ndim']
        self.dataset.add_field(varname=fieldname + 'Adjusted', ndim=ndim, var=field_mod)
        return self.dataset, field_losses

    @staticmethod
    def _as_callable(func):
        r"""Return the plain function wrapped by a ``staticmethod`` (or ``func`` unchanged)."""
        return func.__func__ if isinstance(func, staticmethod) else func

    def _optimize_case(
        self,
        col,
        opt_vars: dict,
        obj_func: callable,
        constr_func: callable,
        optimizer_class,
        scheduler_class,
        config: dict,
        disp_progress: tuple,
    ):
        r"""
        Optimize a single case (column) of the field.

        Returns:
            tuple: ``(adjusted_column, losses)``; ``adjusted_column`` is the NumPy array
            of the optimized values and ``losses`` the per-epoch
            ``[objective, penalty, total]`` history.
        """
        col_tensor = torch.tensor(col, requires_grad=True)
        col_ref = col_tensor.clone().detach()
        optimizer = optimizer_class([col_tensor], lr=config['lr'])
        scheduler = scheduler_class(optimizer, step_size=config['lr_step_size'], gamma=config['lr_gamma'])
        # Loss components of the most recent closure evaluation. The closure
        # returns only the scalar total so optimizers that consume its value
        # (e.g. LBFGS) also work.
        last = {}

        def closure():
            optimizer.zero_grad()
            obj_loss = obj_func(col_tensor, col_ref, **opt_vars)
            if constr_func is None:
                penalty = torch.tensor(0.0)
                total_loss = obj_loss
            else:
                penalty = constr_func(col_tensor, **opt_vars)
                total_loss = obj_loss + config['penalty_factor'] * penalty
            total_loss.backward()
            last['obj'] = float(obj_loss)
            last['penalty'] = float(penalty)
            last['total'] = float(total_loss)
            return total_loss

        show = disp_progress[0]
        # A non-positive display interval would make ``epoch % every`` fail.
        every = max(1, disp_progress[1]) if show else 1
        losses = []
        prev_loss = float('inf')
        n_stalled = 0
        for epoch in range(config['niter']):
            optimizer.step(closure)
            scheduler.step()
            obj, pen, total = last['obj'], last['penalty'], last['total']
            losses.append([obj, pen, total])
            if show and epoch % every == 0:
                pprint(0, f"Epoch {epoch:4}: Total Loss = {total:.2e}, Objective = {obj:.2e}, Penalty = {pen:.2e}")
            # Count consecutive epochs whose loss change stays below tolerance.
            n_stalled = n_stalled + 1 if abs(prev_loss - total) < config['tolerance'] else 0
            prev_loss = total
            if n_stalled >= config['patience']:
                pprint(0, f"Early stopping at epoch {epoch}, no significant improvement.")
                break
            if epoch >= config['niter'] - 1:
                pprint(0, f"Reached maximum number of epochs ({config['niter']}). Stopping.")
                break
        return col_tensor.detach().numpy(), losses