Skip to content

pme

torch_admp.pme

Particle Mesh Ewald (PME) implementation for torch-admp.

This module implements the Coulomb energy calculation using the Particle Mesh Ewald method, which splits the calculation into real-space and reciprocal-space contributions for improved efficiency in periodic systems. It includes support for slab corrections and various optimization methods.

CoulombForceModule(rcut: float, ethresh: float = 1e-05, kspace: bool = True, rspace: bool = True, slab_corr: bool = False, slab_axis: int = 2, units_dict: Optional[Dict] = None, sel: Optional[list[int]] = None, kappa: Optional[float] = None, spacing: Union[List[float], float, None] = None, kmesh: Union[List[int], int, None] = None)

Bases: BaseForceModule

Coulomb energy module with Particle Mesh Ewald (PME).

This module implements the Coulomb energy calculation using the Particle Mesh Ewald method, which splits the calculation into real-space and reciprocal-space contributions for improved efficiency in periodic systems.

PARAMETER DESCRIPTION
rcut

Real-space cutoff distance

TYPE: float

ethresh

Energy threshold for PME accuracy, by default 1e-5

TYPE: float DEFAULT: 1e-05

kspace

Whether to include reciprocal space contribution, by default True

TYPE: bool DEFAULT: True

rspace

Whether to include real space contribution, by default True

TYPE: bool DEFAULT: True

slab_corr

Whether to apply slab correction, by default False

TYPE: bool DEFAULT: False

slab_axis

Axis at which the slab correction is applied, by default 2

TYPE: int DEFAULT: 2

units_dict

Dictionary of unit conversions, by default None

TYPE: Optional[Dict] DEFAULT: None

sel

Selection list for neighbor list, by default None

TYPE: Optional[list[int]] DEFAULT: None

kappa

Inverse screening length [Å^-1], by default None

TYPE: Optional[float] DEFAULT: None

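spacing

Grid spacing for reciprocal space, by default None

TYPE: Union[List[float], float, None] DEFAULT: None

kmesh

User-defined reciprocal-space mesh size; a single int is applied to all three axes, by default None

TYPE: Union[List[int], int, None] DEFAULT: None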

Initialize the CoulombForceModule with PME.

PARAMETER DESCRIPTION
rcut

Real-space cutoff distance

TYPE: float

ethresh

Energy threshold for PME accuracy, by default 1e-5

TYPE: float DEFAULT: 1e-05

kspace

Whether to include reciprocal space contribution, by default True

TYPE: bool DEFAULT: True

rspace

Whether to include real space contribution, by default True

TYPE: bool DEFAULT: True

slab_corr

Whether to apply slab correction, by default False

TYPE: bool DEFAULT: False

slab_axis

Axis at which the slab correction is applied, by default 2

TYPE: int DEFAULT: 2

units_dict

Dictionary of unit conversions, by default None

TYPE: Optional[Dict] DEFAULT: None

sel

Selection list for neighbor list, by default None

TYPE: Optional[list[int]] DEFAULT: None

kappa

Inverse screening length [Å^-1], by default None

TYPE: Optional[float] DEFAULT: None

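spacing

Grid spacing for reciprocal space, by default None

TYPE: Union[List[float], float, None] DEFAULT: None

kmesh

User-defined reciprocal-space mesh size; a single int is applied to all three axes, by default None

TYPE: Union[List[int], int, None] DEFAULT: None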

Source code in torch_admp/pme.py
def __init__(
    self,
    rcut: float,
    ethresh: float = 1e-5,
    kspace: bool = True,
    rspace: bool = True,
    slab_corr: bool = False,
    slab_axis: int = 2,
    units_dict: Optional[Dict] = None,
    sel: Optional[list[int]] = None,
    kappa: Optional[float] = None,
    spacing: Union[List[float], float, None] = None,
    kmesh: Union[List[int], int, None] = None,
) -> None:
    """
    Initialize the CoulombForceModule with PME.

    Parameters
    ----------
    rcut : float
        Real-space cutoff distance
    ethresh : float, optional
        Energy threshold for PME accuracy, by default 1e-5.
        Must be smaller than 0.5 when ``kappa`` is derived from it.
    kspace : bool, optional
        Whether to include reciprocal space contribution, by default True
    rspace : bool, optional
        Whether to include real space contribution, by default True
    slab_corr : bool, optional
        Whether to apply slab correction, by default False
    slab_axis : int, optional
        Axis at which the slab correction is applied, by default 2
    units_dict : Optional[Dict], optional
        Dictionary of unit conversions, by default None
    sel : Optional[list[int]], optional
        Selection list for neighbor list, by default None
    kappa : Optional[float], optional
        Inverse screening length [Å^-1], by default None.
        If None and ``kspace`` is True, it is computed as
        ``sqrt(-ln(2 * ethresh)) / rcut`` divided by the unit library's
        ``length_coeff``; if None and ``kspace`` is False it is set to 0.0.
    spacing : Union[List[float], float, None], optional
        Grid spacing for reciprocal space, by default None.
        A real scalar (int or float) is applied to all three axes.
    kmesh : Union[List[int], int, None], optional
        User-defined reciprocal-space mesh size, by default None.
        A single int is applied to all three axes.

    Raises
    ------
    ValueError
        If ``rcut``, ``ethresh``, ``kappa`` or any ``kmesh``/``spacing``
        entry is not positive, if ``slab_axis`` is not 0/1/2, if
        ``kmesh``/``spacing`` does not have exactly 3 entries, or if
        ``ethresh >= 0.5`` while ``kappa`` has to be derived.
    """
    BaseForceModule.__init__(self, units_dict)

    if rcut <= 0.0:
        raise ValueError(f"rcut must be positive, got {rcut}")

    if ethresh <= 0.0:
        raise ValueError(f"ethresh must be positive, got {ethresh}")

    if slab_axis not in (0, 1, 2):
        raise ValueError(f"slab_axis must be 0/1/2, got {slab_axis}")

    self.kspace_flag = kspace
    if kappa is not None:
        if kappa <= 0.0:
            raise ValueError(f"kappa must be positive, got {kappa}")
        # A user-supplied kappa is stored as given (no unit conversion).
        self.kappa = kappa
    elif self.kspace_flag:
        # sqrt(-ln(2 * ethresh)) is only a positive real number for
        # ethresh < 0.5; otherwise math.sqrt fails or kappa collapses to 0.
        if ethresh >= 0.5:
            raise ValueError(
                f"ethresh must be smaller than 0.5 to derive kappa, got {ethresh}"
            )
        kappa = math.sqrt(-math.log(2 * ethresh)) / rcut
        self.kappa = kappa / getattr(self.const_lib, "length_coeff")
    else:
        self.kappa = 0.0
    self.ethresh = ethresh

    if kmesh is not None:
        # use user-defined kmesh; a scalar is applied to all three axes
        if isinstance(kmesh, (int, np.integer)):
            kmesh = [kmesh, kmesh, kmesh]
        if len(kmesh) != 3:
            raise ValueError(f"kmesh must have 3 entries, got {len(kmesh)}")
        # Validate kmesh values
        for i, k in enumerate(kmesh):
            if k <= 0:
                raise ValueError(
                    f"kmesh values must be positive, got kmesh[{i}] = {k}"
                )
        self.kmesh = to_torch_tensor(np.array(kmesh)).to(torch.long)
    else:
        self.kmesh = None
    # record the actually used kmesh
    self._kmesh = torch.zeros(3, device=DEVICE, dtype=torch.long)
    # use spacing
    if spacing is not None:
        # Any real scalar (e.g. ``spacing=1``) is applied to all three axes;
        # previously only ``float`` was recognised and an int crashed below.
        if isinstance(spacing, (int, float, np.number)):
            spacing = [spacing, spacing, spacing]
        if len(spacing) != 3:
            raise ValueError(f"spacing must have 3 entries, got {len(spacing)}")
        # Validate spacing values
        for i, s in enumerate(spacing):
            if s <= 0:
                raise ValueError(
                    f"spacing values must be positive, got spacing[{i}] = {s}"
                )
        self.spacing = to_torch_tensor(np.array(spacing)).to(
            GLOBAL_PT_FLOAT_PRECISION
        )
    else:
        self.spacing = None

    self.rspace_flag = rspace
    self.slab_corr_flag = slab_corr
    self.slab_axis = slab_axis

    # Per-term energy buffers, initialised to a single zero.
    self.real_energy = to_torch_tensor(np.zeros(1)).to(GLOBAL_PT_FLOAT_PRECISION)
    self.reciprocal_energy = to_torch_tensor(np.zeros(1)).to(
        GLOBAL_PT_FLOAT_PRECISION
    )
    self.self_energy = to_torch_tensor(np.zeros(1)).to(GLOBAL_PT_FLOAT_PRECISION)
    self.non_neutral_energy = to_torch_tensor(np.zeros(1)).to(
        GLOBAL_PT_FLOAT_PRECISION
    )
    self.slab_corr_energy = to_torch_tensor(np.zeros(1)).to(
        GLOBAL_PT_FLOAT_PRECISION
    )

    # Currently only supports pme_order=6
    # Because only the 6-th order spline function is hard implemented
    self.pme_order: int = 6
    n_mesh = int(self.pme_order**3)

    # global variables for the reciprocal module, all related to pme_order:
    # per-axis offsets -pme_order//2 .. pme_order//2 - 1, combined into every
    # 3D offset triple and reshaped to (1, pme_order**3, 3).
    bspline_range = torch.arange(
        -self.pme_order // 2, self.pme_order // 2, device=DEVICE
    )
    shift_y, shift_x, shift_z = torch.meshgrid(
        bspline_range, bspline_range, bspline_range, indexing="ij"
    )
    self.pme_shifts = (
        torch.stack((shift_x, shift_y, shift_z))
        .transpose(0, 3)
        .reshape((1, n_mesh, 3))
    )

    self.rcut = rcut
    self.sel = sel

get_rcut() -> float

Get the cutoff radius.

RETURNS DESCRIPTION
float

Cutoff radius

Source code in torch_admp/pme.py
def get_rcut(self) -> float:
    """
    Return the real-space cutoff radius.

    Returns
    -------
    float
        The ``rcut`` value stored on this module.
    """
    cutoff = self.rcut
    return cutoff

get_sel() -> Optional[list[int]]

Get the `sel` list of the DP model.

RETURNS DESCRIPTION
Optional[list[int]]

The number of selected neighbors for each type of atom.

Source code in torch_admp/pme.py
def get_sel(self) -> Optional[list[int]]:
    """
    Return the `sel` list of the DP model.

    Returns
    -------
    Optional[list[int]]
        The number of selected neighbors for each type of atom, or None
        when no selection is stored.
    """
    selection = self.sel
    return selection

setup_ewald_parameters(rcut: float, box: Union[torch.Tensor, np.ndarray, None] = None, threshold: float = 1e-05, spacing: Optional[float] = None, method: str = 'openmm') -> Tuple[float, int, int, int]

Given the cutoff distance, and the required precision, determine the parameters used in Ewald sum, including: kappa, kx, ky, and kz.

PARAMETER DESCRIPTION
rcut

Cutoff distance

TYPE: float

threshold

Expected average relative errors in force

TYPE: float DEFAULT: 1e-05

box

Lattice vectors as a (3 x 3) matrix. Keep units consistent with rcut.

TYPE: Tensor or ndarray DEFAULT: None

spacing

Fourier spacing used to determine K in the gromacs method. Keep units consistent with rcut.

TYPE: float DEFAULT: None

method

Method to determine Ewald parameters. Valid values: "openmm" or "gromacs". If openmm, the algorithm follows http://docs.openmm.org/latest/userguide/theory/02_standard_forces.html#coulomb-interaction-with-particle-mesh-ewald. If gromacs, the algorithm is adapted from the gromacs source code.

TYPE: str DEFAULT: 'openmm'

RETURNS DESCRIPTION
kappa

Ewald parameter, in units of inverse length

TYPE: float

kx, ky, kz: int

Number of k-points of the reciprocal-space mesh along each axis

Source code in torch_admp/pme.py
def setup_ewald_parameters(
    rcut: float,
    box: Union[torch.Tensor, np.ndarray, None] = None,
    threshold: float = 1e-5,
    spacing: Optional[float] = None,
    method: str = "openmm",
) -> Tuple[float, int, int, int]:
    """
    Given the cutoff distance, and the required precision, determine the parameters used in
    Ewald sum, including: kappa, kx, ky, and kz.

    Parameters
    ----------
    rcut : float
        Cutoff distance
    threshold : float
        Expected average relative errors in force. Must be positive, and
        smaller than 0.5 for the openmm method.
    box : torch.Tensor or np.ndarray
        Lattice vectors as a (3 x 3) matrix.
        Keep units consistent with rcut.
        If None, the placeholder values (0.1, 1, 1, 1) are returned.
    spacing : float, optional
        Fourier spacing used to determine K in the gromacs method.
        Keep units consistent with rcut.
    method : str
        Method to determine Ewald parameters.
        Valid values: "openmm" or "gromacs".
        If openmm, the algorithm follows http://docs.openmm.org/latest/userguide/theory/02_standard_forces.html#coulomb-interaction-with-particle-mesh-ewald
        If gromacs, the algorithm is adapted from the gromacs source code.

    Returns
    -------
    kappa: float
        Ewald parameter, in units of inverse length
    kx, ky, kz: int
        Number of k-points of the mesh along each axis

    Raises
    ------
    ValueError
        If ``rcut`` or ``threshold`` is not positive, if ``box`` is not a
        (3 x 3) orthogonal matrix, if ``threshold >= 0.5`` with the openmm
        method, if ``spacing`` is missing or not positive with the gromacs
        method, or if ``method`` is unknown.
    """
    if rcut <= 0.0:
        raise ValueError(f"rcut must be positive, got {rcut}")
    if threshold <= 0.0:
        raise ValueError(f"threshold must be positive, got {threshold}")

    if box is None:
        return 0.1, 1, 1, 1

    if isinstance(box, torch.Tensor):
        box = to_numpy_array(box)
    box = np.asarray(box, dtype=np.float64)
    if box.shape != (3, 3):
        raise ValueError(f"box must be a (3 x 3) matrix, got shape {box.shape}")

    # Only orthogonal boxes are supported: the mesh sizes below use the box
    # diagonal only. Use a relative tolerance so round-off noise in
    # off-diagonal entries does not reject an otherwise orthogonal box.
    # (Explicit raises instead of `assert`, which is stripped under -O.)
    row_norms = np.linalg.norm(box, axis=1)
    for i, j in ((0, 1), (0, 2), (1, 2)):
        if abs(np.inner(box[i], box[j])) > 1e-10 * row_norms[i] * row_norms[j]:
            raise ValueError("Only orthogonal box is supported currently.")
    box_lengths = np.diag(box)

    if method == "openmm":
        # sqrt(-ln(2 * threshold)) is only real and positive for threshold < 0.5.
        if threshold >= 0.5:
            raise ValueError(
                f"threshold must be smaller than 0.5 for the openmm method, got {threshold}"
            )
        kappa = np.sqrt(-np.log(2 * threshold)) / rcut
        kmesh = np.ceil(2 * kappa * box_lengths / (3.0 * threshold ** (1.0 / 5.0)))
    elif method == "gromacs":
        if spacing is None:
            raise ValueError("Spacing must be provided for gromacs method.")
        if spacing <= 0.0:
            raise ValueError(f"spacing must be positive, got {spacing}")
        # determine kappa: double until erfc(kappa * rcut) <= threshold,
        # then bisect between 0 and that upper bound.
        kappa = 5.0
        n_doublings = 0
        while special.erfc(kappa * rcut) > threshold:
            n_doublings += 1
            kappa *= 2

        low, high = 0.0, kappa
        for _ in range(n_doublings + 60):
            kappa = (low + high) / 2
            if special.erfc(kappa * rcut) > threshold:
                low = kappa
            else:
                high = kappa
        # determine K
        kmesh = np.ceil(box_lengths / spacing)
    else:
        raise ValueError(
            f"Invalid method: {method}. Valid methods: 'openmm', 'gromacs'"
        )

    # Return plain Python scalars, matching the annotated return type.
    kx, ky, kz = (int(k) for k in kmesh.astype(int))
    return float(kappa), kx, ky, kz