Computing discrete Fourier transforms¶
Discrete Fourier transforms are computed through the Fast Fourier Transform (FFT) algorithm, as implemented in the FFTW library. The module janus.fft provides a Python wrapper to this C library and exposes both serial and parallel (MPI) implementations through a unified interface.
Before the main methods and functions of the janus.fft module are introduced, an important design decision should be mentioned. In the present implementation of the module, input data (to be transformed) is not passed directly to FFTW. Rather, a local copy is made first, and FFTW then operates on this local copy. This allows the same plan to be reused for many transforms, which is advantageous in the context of iterative solvers. The extra copy does induce a performance hit, but it is deemed negligible for transforms of large 2D or 3D arrays.
Although not essential, it might be useful to have a look at the FFTW manual. For the time being, only two- and three-dimensional real-to-complex transforms are implemented.
Serial computations¶
The following piece of code creates an object transform which can perform real FFTs on 32x64 grids.
>>> import janus.fft.serial
>>> transform = janus.fft.serial.create_real((32, 64))
The function janus.fft.serial.create_real() can be passed planner flags (see Planner Flags in the FFTW manual). The attributes of the returned object are

transform.global_ishape
    the global shape of the input (real) array,
transform.ishape
    the local shape of the input (real) array,
transform.global_oshape
    the global shape of the output (complex) array,
transform.oshape
    the local shape of the output (complex) array.

For serial transforms, local and global shapes coincide.
>>> transform.global_ishape
(32, 64)
>>> transform.ishape
(32, 64)
>>> transform.global_oshape
(32, 66)
>>> transform.oshape
(32, 66)
It should be noted that complex-valued arrays are stored according to the FFTW convention: even (resp. odd) values of the fast index correspond to the real (resp. imaginary) part of the complex number (see also Multi-Dimensional DFTs of Real Data in the FFTW manual). For a real input whose last dimension is 64, only the 64 // 2 + 1 = 33 non-redundant complex coefficients are stored, each as a pair of real numbers, whence the output shape (32, 66).
Direct (real-to-complex) transforms are computed through the method transform.r2c(), which takes as input a MemoryView of shape transform.ishape and returns a MemoryView of shape transform.oshape.
>>> import numpy as np
>>> np.random.seed(20150223)
>>> x = np.random.rand(*transform.ishape)
>>> y1 = transform.r2c(x)
It should be noted that y1 is a MemoryView, not a numpy array; it can, however, readily be converted into an array
>>> print(y1)
<MemoryView of 'array' object>
>>> y1 = np.asarray(y1)
>>> type(y1)
<class 'numpy.ndarray'>
The output can be converted to an array of complex numbers
>>> actual = y1[..., 0::2] + 1j * y1[..., 1::2]
>>> actual.shape
(32, 33)
and compared to the FFT of x computed by means of the numpy.fft module
>>> expected = np.fft.rfftn(x)
>>> expected.shape
(32, 33)
>>> abs_delta = np.absolute(expected - actual)
>>> abs_exp = np.absolute(expected)
>>> error = np.sqrt(np.sum(abs_delta**2) / np.sum(abs_exp**2))
>>> assert error < 1E-15
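As a side remark, since the real and imaginary parts are interleaved along the (contiguous) last axis, the same complex array can also be obtained through a NumPy view; this is merely an equivalent shortcut, not part of the janus API

>>> alternative = np.ascontiguousarray(y1).view(np.complex128)
>>> np.array_equal(alternative, actual)
True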
The inverse (complex-to-real) discrete Fourier transform is computed through the method transform.c2r()
>>> x1 = transform.c2r(y1)
>>> error = np.sqrt(np.sum((x1 - x)**2) / np.sum(x**2))
>>> assert error < 1E-15
It should be noted that a pre-allocated output array can be passed as an argument to both transform.r2c()
>>> y2 = np.empty(transform.oshape)
>>> out = transform.r2c(x, y2)
>>> assert out.base is y2
>>> assert np.sum((y2 - y1)**2) == 0.0
and transform.c2r()
>>> x2 = np.empty(transform.ishape)
>>> out = transform.c2r(y1, x2)
>>> assert out.base is x2
>>> assert np.sum((x2 - x1)**2) == 0.0
Parallel computations¶
The module janus.fft.parallel is a wrapper around the fftw3-mpi library (refer to Distributed-memory FFTW with MPI in the FFTW manual for the inner workings of this library). This module must be used along with the mpi4py module to handle MPI communications.
The Python API is very similar to the API for serial transforms. However, computing a parallel FFT is slightly more involved than computing a serial FFT, because the data must be distributed across the processes. The computation goes through the following steps:

1. create the input data (root process),
2. create a transform object (all processes),
3. gather the local sizes and displacements (root process),
4. scatter the input data according to the previously gathered local sizes (root process),
5. compute the transform (all processes),
6. gather the results (root process).
This is illustrated in the step-by-step tutorial below, which again computes a 32x64 real Fourier transform. The full source can be downloaded here; it must be run through the following command line:
$ mpiexec -np 2 python3 parallel_fft_tutorial.py
where the number of processes can be adjusted (all output produced below was obtained with two parallel processes).
Before we proceed with the description of the program, it should be noted that communications are carried out with the uppercase methods MPI.Comm.Gather and MPI.Comm.Scatter (and their vector variants). The lowercase versions MPI.Comm.gather and MPI.Comm.scatter are slightly easier to use, but they communicate objects through pickling; this approach fails for very large objects, as the size limit is much lower than the intrinsic MPI size limit. With MPI.Comm.Gather and MPI.Comm.Scatter, the intrinsic MPI size limit is restored. The FFT objects defined in the module janus.fft.parallel provide attributes to help call these methods.
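For illustration only, the following minimal sketch (independent of janus, run with mpiexec as above) contrasts the two interfaces; the variable names are arbitrary:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Lowercase: arbitrary Python objects, communicated through pickling.
pairs = comm.gather((comm.rank, 'hello'), root=0)

# Uppercase: buffer-like objects (e.g. NumPy arrays), communicated directly,
# thus avoiding pickling and its size limitations.
x_loc = np.full(4, comm.rank, dtype=np.float64)
x = np.empty(4 * comm.size, dtype=np.float64) if comm.rank == 0 else None
comm.Gather(x_loc, x, root=0)

if comm.rank == 0:
    print(pairs)  # e.g. [(0, 'hello'), (1, 'hello')]
    print(x)      # e.g. [0. 0. 0. 0. 1. 1. 1. 1.]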
A few modules must first be imported
import numpy as np
import janus.fft.parallel
import janus.fft.serial  # used in the final check against a serial transform
from mpi4py import MPI
Then, some useful variables are created
if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    root = 0
    shape = (32, 64)
Then, the transform objects (one for each process) are created (step 2), and their various shapes are printed out.
transform = janus.fft.parallel.create_real(shape, comm)
if comm.rank == root:
    print('global_ishape = {}'.format(transform.global_ishape))
    print('global_oshape = {}'.format(transform.global_oshape))
    print('ishape = {}'.format(transform.ishape))
    print('oshape = {}'.format(transform.oshape))
This code snippet outputs the following messages
global_ishape = (32, 64)
global_oshape = (32, 66)
ishape = (16, 64)
oshape = (16, 66)
The transform.global_ishape and transform.global_oshape attributes refer to the global (logical) shapes of the input and output arrays. Since the data is distributed across the processes, the local, in-memory shapes of the input and output data differ from these global shapes: the transform.ishape (resp. transform.oshape) attribute refers to the local shape of the real, input (resp. complex, output) data for the current process. As expected with FFTW, the data is distributed along the first dimension. Indeed, the global first dimension is 32, and the above example is run with 2 processes; therefore, the local first dimension is 32 / 2 = 16.
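As a quick sanity check (a mere sketch, assuming the number of rows divides evenly among the processes), one may verify that only the first dimension is distributed:

# Each process owns a contiguous slab of rows; the remaining
# dimensions are not distributed.
assert transform.ishape[0] * comm.size == transform.global_ishape[0]
assert transform.ishape[1:] == transform.global_ishape[1:]
assert transform.oshape[1:] == transform.global_oshape[1:]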
In order to figure out how to scatter the input data, the root process then gathers all local sizes and displacements, and the parameters to be passed to MPI.Comm.Scatterv and MPI.Comm.Gatherv are prepared
counts_and_displs = comm.gather(sendobj=(transform.isize, transform.idispl,
                                         transform.osize, transform.odispl),
                                root=root)
if comm.rank == root:
    np.random.seed(20150310)
    x = np.random.rand(*shape)
    icounts, idispls, ocounts, odispls = zip(*counts_and_displs)
else:
    x, icounts, idispls, ocounts, odispls = None, None, None, None, None
Then the input data x is scattered across all processes
x_loc = np.empty(transform.ishape, dtype=np.float64)
comm.Scatterv([x, icounts, idispls, MPI.DOUBLE], x_loc, root)
Each process then executes its transform
y_loc = transform.r2c(x_loc)
and the root process finally gathers the results
if comm.rank == root:
    y = np.empty(transform.global_oshape, dtype=np.float64)
else:
    y = None
comm.Gatherv(y_loc, [y, ocounts, odispls, MPI.DOUBLE], root)
To check that the computation is correct, the same transform is finally computed locally by the root process
if comm.rank == root:
    serial_transform = janus.fft.serial.create_real(shape)
    y_ref = np.asarray(serial_transform.r2c(x))
    err = np.sum((y - y_ref)**2) / np.sum(y_ref**2)
    assert err <= np.finfo(np.float64).eps
The complete program¶
# Imports
import numpy as np
import janus.fft.parallel
import janus.fft.serial  # used in the final check against a serial transform
from mpi4py import MPI

# Init some variables
if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    root = 0
    shape = (32, 64)

    # Create transform objects
    transform = janus.fft.parallel.create_real(shape, comm)
    if comm.rank == root:
        print('global_ishape = {}'.format(transform.global_ishape))
        print('global_oshape = {}'.format(transform.global_oshape))
        print('ishape = {}'.format(transform.ishape))
        print('oshape = {}'.format(transform.oshape))

    # Prepare communications
    counts_and_displs = comm.gather(sendobj=(transform.isize, transform.idispl,
                                             transform.osize, transform.odispl),
                                    root=root)
    if comm.rank == root:
        np.random.seed(20150310)
        x = np.random.rand(*shape)
        icounts, idispls, ocounts, odispls = zip(*counts_and_displs)
    else:
        x, icounts, idispls, ocounts, odispls = None, None, None, None, None

    # Scatter input data
    x_loc = np.empty(transform.ishape, dtype=np.float64)
    comm.Scatterv([x, icounts, idispls, MPI.DOUBLE], x_loc, root)

    # Execute transform
    y_loc = transform.r2c(x_loc)

    # Gather output data
    if comm.rank == root:
        y = np.empty(transform.global_oshape, dtype=np.float64)
    else:
        y = None
    comm.Gatherv(y_loc, [y, ocounts, odispls, MPI.DOUBLE], root)

    # Validate result
    if comm.rank == root:
        serial_transform = janus.fft.serial.create_real(shape)
        y_ref = np.asarray(serial_transform.r2c(x))
        err = np.sum((y - y_ref)**2) / np.sum(y_ref**2)
        assert err <= np.finfo(np.float64).eps