Computing discrete Fourier transforms

Discrete Fourier transforms are computed through the fast Fourier transform (FFT) algorithm, as implemented in the FFTW library. The module janus.fft provides a Python wrapper around this C library; it exposes both serial and parallel (MPI) implementations through a unified interface.

Before the main methods and functions of the janus.fft module are introduced, an important design issue should be mentioned. In the present implementation of the module, the input data (to be transformed) is not passed directly to FFTW. Rather, a local copy is first made, and FFTW then operates on this local copy. This allows the same plan to be reused for many transforms, which is advantageous in the context of iterative solvers. The extra copy certainly induces a performance hit, which is deemed negligible for transforms of large 2D or 3D arrays.
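In other words, the intended usage pattern is to create a transform object once and to reuse it many times. The following minimal sketch illustrates this pattern for a fictitious iterative solver (the loop itself is hypothetical; the janus API used here is described in the next sections).

import numpy as np

import janus.fft.serial

# The plan is created once...
transform = janus.fft.serial.create_real((32, 64))
x = np.random.rand(*transform.ishape)
for iteration in range(10):
    # ... and reused at each iteration of the (hypothetical) solver.
    y = transform.r2c(x)
    # update x from y (problem-dependent)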

Although not essential, it might be useful to have a look at the FFTW manual. For the time being, only two- and three-dimensional real-to-complex transforms are implemented.

Serial computations

The following piece of code creates an object transform which can perform FFTs on 32x64 grids of real numbers.

>>> import janus.fft.serial
>>> transform = janus.fft.serial.create_real((32, 64))

The function janus.fft.serial.create_real() can be passed planner flags (see Planner Flags in the FFTW manual). The attributes of the returned object are

  • transform.global_ishape contains the global shape of the input array,

  • transform.ishape contains the local shape of the input (real) array,

  • transform.global_oshape contains the global shape of the output (complex) array,

  • transform.oshape contains the local shape of the output (complex) array.

For serial transforms, local and global shapes coincide.

>>> transform.global_ishape
(32, 64)
>>> transform.ishape
(32, 64)
>>> transform.global_oshape
(32, 66)
>>> transform.oshape
(32, 66)

It should be noted that complex-valued arrays are stored according to the FFTW convention: even (resp. odd) values of the fast index correspond to the real (resp. imaginary) part of the complex number (see also Multi-Dimensional DFTs of Real Data in the FFTW manual).

Direct (real-to-complex) transforms are computed through the method transform.r2c(), which takes as input a MemoryView of shape transform.ishape, and returns a MemoryView of shape transform.oshape.

>>> import numpy as np
>>> np.random.seed(20150223)
>>> x = np.random.rand(*transform.ishape)
>>> y1 = transform.r2c(x)

It should be noted that y1 is a MemoryView, not a numpy array; it can, however, readily be converted into an array

>>> print(y1)
<MemoryView of 'array' object>
>>> y1 = np.asarray(y1)
>>> type(y1)
<class 'numpy.ndarray'>

The output can be converted to an array of complex numbers

>>> actual = y1[..., 0::2] + 1j * y1[..., 1::2]
>>> actual.shape
(32, 33)

and compared to the FFT of x computed by means of the numpy.fft module

>>> expected = np.fft.rfftn(x)
>>> expected.shape
(32, 33)
>>> abs_delta = np.absolute(expected - actual)
>>> abs_exp = np.absolute(expected)
>>> error = np.sqrt(np.sum(abs_delta**2) / np.sum(abs_exp**2))
>>> assert error < 1E-15

Inverse (complex-to-real) discrete Fourier transforms are computed through the method transform.c2r()

>>> x1 = transform.c2r(y1)
>>> error = np.sqrt(np.sum((x1 - x)**2) / np.sum(x**2))
>>> assert error < 1E-15
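If the spectrum is available as an array of complex numbers (like expected above), it must first be packed into the interleaved layout described earlier. The following sketch assumes that transform.c2r() accepts any C-contiguous float64 array of shape transform.oshape (not only the MemoryView returned by transform.r2c())

>>> packed = np.empty(transform.oshape)
>>> packed[..., 0::2] = expected.real
>>> packed[..., 1::2] = expected.imag
>>> x3 = np.asarray(transform.c2r(packed))
>>> error = np.sqrt(np.sum((x3 - x)**2) / np.sum(x**2))
>>> assert error < 1E-15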

It should be noted that a pre-allocated output array can be passed as an argument to both transform.r2c()

>>> y2 = np.empty(transform.oshape)
>>> out = transform.r2c(x, y2)
>>> assert out.base is y2
>>> assert np.sum((y2 - y1)**2) == 0.0

and transform.c2r()

>>> x2 = np.empty(transform.ishape)
>>> out = transform.c2r(y1, x2)
>>> assert out.base is x2
>>> assert np.sum((x2 - x1)**2) == 0.0
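Passing pre-allocated output arrays is most useful in iterative algorithms, where the same transform is applied many times: it avoids a fresh allocation at each iteration. A minimal sketch (the loop itself is hypothetical)

>>> x_buf = np.empty(transform.ishape)
>>> y_buf = np.empty(transform.oshape)
>>> for iteration in range(10):
...     x_buf[...] = np.random.rand(*transform.ishape)
...     transform.r2c(x_buf, y_buf)
...     # use y_buf here, then update x_buf (problem-dependent)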

Parallel computations

The module janus.fft.parallel is a wrapper around the fftw3-mpi library (refer to Distributed-memory FFTW with MPI in the FFTW manual for the inner workings of this library). This module must be used along with the mpi4py module to handle MPI communications.

The Python API is very similar to the API for serial transforms. However, computing a parallel FFT is slightly more involved than computing a serial FFT, because the data must be distributed across the processes. The computation must go through the following steps

  1. create input data (root process),

  2. create a transform object (all processes),

  3. gather the local sizes and displacements (root process),

  4. scatter the input data according to the previously gathered local sizes (root process),

  5. compute the transform (all processes),

  6. gather the results (root process).

This is illustrated in the step-by-step tutorial below, which again computes a 32x64 real Fourier transform. The full source can be downloaded here; it must be run with the following command line:

$ mpiexec -np 2 python3 parallel_fft_tutorial.py

where the number of processes can be adjusted (all output produced below was obtained with two parallel processes).

Before we proceed with the description of the program, it should be noted that communication is carried out with the uppercase versions MPI.Comm.Gather and MPI.Comm.Scatter. The lowercase versions MPI.Comm.gather and MPI.Comm.scatter are slightly easier to use, but they communicate objects through pickling; this approach fails with very large objects (the size limit is much lower than the intrinsic MPI size limit). With MPI.Comm.Gather and MPI.Comm.Scatter, the intrinsic MPI size limit is restored. The FFT objects defined in the module janus.fft.parallel provide attributes to help call these methods.
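To illustrate the difference independently of janus, the following sketch gathers a Python object with the lowercase MPI.Comm.gather, and a numpy buffer with the uppercase MPI.Comm.Gather (the array size is arbitrary).

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Lowercase: arbitrary Python objects, communicated through pickling.
objs = comm.gather({'rank': comm.rank}, root=0)

# Uppercase: buffer-like objects (e.g. numpy arrays), no pickling involved.
sendbuf = np.full(4, comm.rank, dtype=np.float64)
recvbuf = np.empty(4 * comm.size, dtype=np.float64) if comm.rank == 0 else None
comm.Gather(sendbuf, recvbuf, root=0)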

A few modules must first be imported

import numpy as np

import janus.fft.parallel
import janus.fft.serial

from mpi4py import MPI

Then, some useful variables are created

if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    root = 0
    shape = (32, 64)

Then, the transform objects (one for each process) are created (step 2), and their various shapes are printed out.

    transform = janus.fft.parallel.create_real(shape, comm)
    if comm.rank == root:
        print('global_ishape  = {}'.format(transform.global_ishape))
        print('global_oshape  = {}'.format(transform.global_oshape))
        print('ishape = {}'.format(transform.ishape))
        print('oshape = {}'.format(transform.oshape))

This code snippet outputs the following messages

global_ishape  = (32, 64)
global_oshape  = (32, 66)
ishape = (16, 64)
oshape = (16, 66)

The transform.global_ishape and transform.global_oshape attributes refer to the global (logical) shapes of the input and output data. Since the data is distributed across all processes, the local shapes in memory differ from the global ones: transform.ishape (resp. transform.oshape) refers to the local shape of the real, input (resp. complex, output) data for the current process. As expected with FFTW, it is observed that the data is distributed with respect to the first dimension: the global first dimension is 32, and the above example is run with 2 processes; therefore, the local first dimension is 32 / 2 = 16.
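To make the slab decomposition explicit, each process can also report its own local shapes (a minimal sketch; the output of the different ranks may appear in any order):

    print('rank {}: ishape = {}, oshape = {}'.format(comm.rank,
                                                     transform.ishape,
                                                     transform.oshape))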

In order to figure out how to scatter the input data, the root process gathers all local sizes and displacements, from which the parameters to be passed to MPI.Comm.Scatterv() and MPI.Comm.Gatherv() are prepared

    counts_and_displs = comm.gather(sendobj=(transform.isize, transform.idispl,
                                             transform.osize, transform.odispl),
                                    root=root)
    if comm.rank == root:
        np.random.seed(20150310)
        x = np.random.rand(*shape)
        icounts, idispls, ocounts, odispls = zip(*counts_and_displs)
    else:
        x, icounts, idispls, ocounts, odispls = None, None, None, None, None

Then the input data x is scattered across all processes

    x_loc = np.empty(transform.ishape, dtype=np.float64)
    comm.Scatterv([x, icounts, idispls, MPI.DOUBLE], x_loc, root)

Each process then executes its transform

    y_loc = transform.r2c(x_loc)

and the root process finally gathers the results

    if comm.rank == root:
        y = np.empty(transform.global_oshape, dtype=np.float64)
    else:
        y = None
    comm.Gatherv(y_loc, [y, ocounts, odispls, MPI.DOUBLE], root)

To check that the computation is correct, the same transform is finally computed locally by the root process

    if comm.rank == root:
        serial_transform = janus.fft.serial.create_real(shape)
        y_ref = np.asarray(serial_transform.r2c(x))
        err = np.sum((y-y_ref)**2) / np.sum(y_ref**2)
        assert err <= np.finfo(np.float64).eps
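Note that, on the root process, the gathered array y stores the transform in the interleaved real/imaginary layout described in the serial section. If a genuine complex-valued array is needed, it can be rebuilt in the same way (a sketch, assuming the same layout as in the serial case):

    if comm.rank == root:
        y_complex = y[..., 0::2] + 1j * y[..., 1::2]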

The complete program

# Imports
import numpy as np

import janus.fft.parallel
import janus.fft.serial

from mpi4py import MPI

# Init some variables
if __name__ == '__main__':
    comm = MPI.COMM_WORLD
    root = 0
    shape = (32, 64)
    # Create transform objects
    transform = janus.fft.parallel.create_real(shape, comm)
    if comm.rank == root:
        print('global_ishape  = {}'.format(transform.global_ishape))
        print('global_oshape  = {}'.format(transform.global_oshape))
        print('ishape = {}'.format(transform.ishape))
        print('oshape = {}'.format(transform.oshape))
    # Prepare communications
    counts_and_displs = comm.gather(sendobj=(transform.isize, transform.idispl,
                                             transform.osize, transform.odispl),
                                    root=root)
    if comm.rank == root:
        np.random.seed(20150310)
        x = np.random.rand(*shape)
        icounts, idispls, ocounts, odispls = zip(*counts_and_displs)
    else:
        x, icounts, idispls, ocounts, odispls = None, None, None, None, None
    # Scatter input data
    x_loc = np.empty(transform.ishape, dtype=np.float64)
    comm.Scatterv([x, icounts, idispls, MPI.DOUBLE], x_loc, root)
    # Execute transform
    y_loc = transform.r2c(x_loc)
    # Gather output data
    if comm.rank == root:
        y = np.empty(transform.global_oshape, dtype=np.float64)
    else:
        y = None
    comm.Gatherv(y_loc, [y, ocounts, odispls, MPI.DOUBLE], root)
    # Validate result
    if comm.rank == root:
        serial_transform = janus.fft.serial.create_real(shape)
        y_ref = np.asarray(serial_transform.r2c(x))
        err = np.sum((y-y_ref)**2) / np.sum(y_ref**2)
        assert err <= np.finfo(np.float64).eps