#!/usr/bin/env python
#
# pyLOM, utils.
#
# Parallel MPI routines
#
# Last rev: 14/02/2025
from __future__ import print_function, division
import mpi4py
import numpy as np
mpi4py.rc.recv_mprobe = False
from mpi4py import MPI
from .parall import split
from .nvtxp import nvtxp
MPI_COMM = MPI.COMM_WORLD
MPI_RANK = MPI_COMM.Get_rank()
MPI_SIZE = MPI_COMM.Get_size()
MPI_RDONLY = MPI.MODE_RDONLY
MPI_WRONLY = MPI.MODE_WRONLY
MPI_CREATE = MPI.MODE_CREATE
# Expose functions from MPI library
mpi_create_op = MPI.Op.Create
mpi_wtime = MPI.Wtime
mpi_file_open = MPI.File.Open
mpi_nanmin = mpi_create_op(lambda v1,v2,dtype : np.nanmin([v1,v2]),commute=True)
mpi_nanmax = mpi_create_op(lambda v1,v2,dtype : np.nanmax([v1,v2]),commute=True)
mpi_nansum = mpi_create_op(lambda v1,v2,dtype : np.nansum([v1,v2]),commute=True)
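# The nan-aware reductions above combine two partial values while ignoring NaNs.
# Example (sketch): they are selected in mpi_reduce below through the string names
# 'nansum', 'nanmin' and 'nanmax', e.g.:
#   total = mpi_reduce(partial, op='nansum', all=True)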
def mpi_barrier():
    '''
    Implements the barrier
    '''
    MPI_COMM.Barrier()

@nvtxp('mpi_send',color='red')
def mpi_send(f,dest,tag=0):
    '''
    Implements the send operation
    '''
    MPI_COMM.send(f,dest,tag=tag)

@nvtxp('mpi_recv',color='red')
def mpi_recv(**kwargs):
    '''
    Implements the receive operation
    '''
    return MPI_COMM.recv(**kwargs)

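# Example (sketch, assuming at least two ranks): a simple point-to-point exchange
# where rank 0 sends an object to rank 1, which receives it with a matching tag:
#   if MPI_RANK == 0: mpi_send({'data': 1}, dest=1, tag=7)
#   if MPI_RANK == 1: obj = mpi_recv(source=0, tag=7)
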
@nvtxp('mpi_sendrecv',color='red')
def mpi_sendrecv(buff,**kwargs):
    '''
    Implements the sendrecv operation
    '''
    return MPI_COMM.sendrecv(buff,**kwargs)

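# Example (sketch): exchange data with neighbouring ranks in a ring, sending to the
# next rank while receiving from the previous one (variable names illustrative only):
#   nxt, prv = (MPI_RANK+1)%MPI_SIZE, (MPI_RANK-1)%MPI_SIZE
#   incoming = mpi_sendrecv(outgoing, dest=nxt, source=prv)
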
@nvtxp('mpi_scatter',color='red')
def mpi_scatter(sendbuff,root=0,do_split=False):
    '''
    Scatter an array among the processors, splitting
    it into per-rank chunks first if requested.
    '''
    if MPI_SIZE > 1:
        return MPI_COMM.scatter(split(sendbuff,root=root),root=root) if do_split else MPI_COMM.scatter(sendbuff,root=root)
    return sendbuff

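# Example (sketch): scatter a NumPy array from rank 0, letting split() divide it
# into per-rank chunks first (array size chosen for illustration only):
#   local = mpi_scatter(np.arange(100), root=0, do_split=True)
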
@nvtxp('mpi_gather',color='red')
def mpi_gather(sendbuff,root=0,all=False):
    '''
    Gather an array from all the processors.
    '''
    if MPI_SIZE > 1:
        # Wrap scalars so that the gathered pieces can be concatenated
        if not isinstance(sendbuff,np.ndarray) and not isinstance(sendbuff,list): sendbuff = [sendbuff]
        if all:
            out = MPI_COMM.allgather(sendbuff)
            return np.concatenate(out,axis=0)
        else:
            out = MPI_COMM.gather(sendbuff,root=root)
            return np.concatenate(out,axis=0) if MPI_RANK == root else None
    return sendbuff

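# Example (sketch): collect per-rank chunks into a single array. With all=True every
# rank receives the concatenated result; with all=False only the root does:
#   full = mpi_gather(local, root=0, all=True)
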
@nvtxp('mpi_reduce',color='red')
def mpi_reduce(sendbuff,root=0,op='sum',all=False):
    '''
    Reduce an array from all the processors.
    '''
    if MPI_SIZE > 1:
        if isinstance(op,str):
            # Map the operation name to an MPI operation; later matches override
            # earlier ones (e.g., 'nansum' also matches 'sum' first)
            if 'sum'    in op: opf = MPI.SUM
            if 'max'    in op: opf = MPI.MAX
            if 'min'    in op: opf = MPI.MIN
            if 'nanmin' in op: opf = mpi_nanmin
            if 'nanmax' in op: opf = mpi_nanmax
            if 'argmin' in op: opf = MPI.MINLOC
            if 'argmax' in op: opf = MPI.MAXLOC
            if 'nansum' in op: opf = mpi_nansum
        else:
            opf = op
        if all:
            return MPI_COMM.allreduce(sendbuff,op=opf)
        else:
            out = MPI_COMM.reduce(sendbuff,op=opf,root=root)
            return out if root == MPI_RANK else sendbuff
    else:
        return sendbuff

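# Example (sketch): sum a per-rank partial result across all ranks. With all=True
# every rank obtains the reduced value (allreduce); with all=False only the root does:
#   total = mpi_reduce(partial, op='sum', all=True)
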
@nvtxp('mpi_bcast',color='red')
def mpi_bcast(sendbuff,root=0):
    '''
    Implements the broadcast operation
    '''
    return MPI_COMM.bcast(sendbuff,root=root)

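
# Minimal usage sketch, not part of the pyLOM API: run it under an MPI launcher
# such as mpiexec; the array size and values below are illustrative only.
if __name__ == '__main__':
    # Rank 0 builds a small array which is broadcast to every rank
    data = np.arange(8,dtype=np.double) if MPI_RANK == 0 else None
    data = mpi_bcast(data,root=0)
    # Each rank contributes its rank id; with all=True the sum is returned everywhere
    total = mpi_reduce(float(MPI_RANK),op='sum',all=True)
    # Gather one copy of the broadcast array per rank back on the root
    full = mpi_gather(data,root=0,all=False)
    mpi_barrier()
    if MPI_RANK == 0:
        print('sum of ranks:',total,'gathered size:',full.shape)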