optimizer

torch_admp.optimizer

Optimization algorithms for torch-admp.

This module implements the optimization algorithms used for charge equilibration and related tasks in the torch-admp package, including line search, conjugate-gradient direction updates, and supporting utilities.
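
The objective is supplied as a pair of callables: one returning the scalar value and one returning its gradient. As a minimal, illustrative sketch of driving the solver (the quadratic objective, the matrix A, the vector b, and the starting point below are made-up placeholders, not part of torch-admp):

import torch
from torch_admp.optimizer import quadratic_optimize

# Toy quadratic objective f(x) = 0.5 * x^T A x - b^T x (illustrative data only)
A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
b = torch.tensor([1.0, 1.0, 1.0])

def func_value(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x @ A @ x - b @ x

def func_grads(fk: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Matches the (value, point) -> gradient signature the optimizer expects.
    return torch.autograd.grad(fk, x)[0]

# The starting point must require gradients for this autograd-based func_grads.
x0 = torch.zeros(3, requires_grad=True)
xk, fk, gk, converge_iter = quadratic_optimize(func_value, func_grads, x0)
print(xk, converge_iter)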

line_search(func_value: Callable, func_grads: Callable, x0: torch.Tensor, eps: float = 1e-06, fk: torch.Tensor = None, gk: torch.Tensor = None, pk: torch.Tensor = None, **kwargs) -> torch.Tensor

Perform line search to find optimal step size.

PARAMETER DESCRIPTION
func_value

Function to compute the value of the objective function

TYPE: Callable

func_grads

Function to compute gradients of the objective function

TYPE: Callable

x0

Initial point

TYPE: Tensor

eps

Convergence threshold, by default 1e-6

TYPE: float DEFAULT: 1e-06

fk

Function value at x0, by default None

TYPE: Tensor DEFAULT: None

gk

Gradient at x0, by default None

TYPE: Tensor DEFAULT: None

pk

Search direction, by default None

TYPE: Tensor DEFAULT: None

**kwargs

Additional keyword arguments passed to func_value and func_grads

DEFAULT: {}

RETURNS DESCRIPTION
Tensor

Optimal point found by line search

Source code in torch_admp/optimizer.py
@torch.jit.unused
def line_search(
    func_value: Callable,
    func_grads: Callable,
    x0: torch.Tensor,
    eps: float = 1e-6,
    fk: torch.Tensor = None,
    gk: torch.Tensor = None,
    pk: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Perform line search to find optimal step size.

    Parameters
    ----------
    func_value : Callable
        Function to compute the value of the objective function
    func_grads : Callable
        Function to compute gradients of the objective function
    x0 : torch.Tensor
        Initial point
    eps : float, optional
        Convergence threshold, by default 1e-6
    fk : torch.Tensor, optional
        Function value at x0, by default None
    gk : torch.Tensor, optional
        Gradient at x0, by default None
    pk : torch.Tensor, optional
        Search direction, by default None
    **kwargs
        Additional keyword arguments passed to func_value and func_grads

    Returns
    -------
    torch.Tensor
        Optimal point found by line search
    """
    history_x = torch.arange(3, dtype=x0.dtype, device=x0.device)
    if fk is None:
        x0 = x0.detach()
        x0.requires_grad = True
        if x0.grad is not None:
            x0.grad.detach_()
            x0.grad.zero_()
        fk = func_value(x0, **kwargs)
        gk = func_grads(fk, x0)
        pk = -gk / torch.norm(gk)
    history_f = [fk]

    xk = x0.detach()
    for _ in range(2):
        if torch.norm(gk) / xk.shape[0] < eps:
            return xk
        xk = xk + pk
        xk.detach_()
        xk.requires_grad = True
        if xk.grad is not None:
            xk.grad.detach_()
            xk.grad.zero_()
        fk = func_value(xk, **kwargs)
        gk = func_grads(fk, xk)
        fk.detach_()
        history_f.append(fk)

    coeff_matrix = torch.stack(
        [history_x**2, history_x, torch.ones_like(history_x)], dim=1
    )
    y = torch.stack(history_f)
    coeff = torch.linalg.solve(coeff_matrix, y)
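    # coeff = (a, b, c) of the fitted parabola f(t) = a*t**2 + b*t + c;
    # its minimizer t* = -b / (2a) gives x_opt = x0 + t* * pk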
    x_opt = x0 - coeff[1] / (2 * coeff[0]) * pk
    return x_opt
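
Restated in conventional notation (a summary of the code above, not a separate algorithm): the routine evaluates the objective at three unit-spaced points along the normalized direction p_k, fits a one-dimensional quadratic model, and jumps to its analytic minimizer,

\phi(t) = f(x_0 + t\, p_k) \approx a t^2 + b t + c, \qquad t^{*} = -\frac{b}{2a}, \qquad x_{\mathrm{opt}} = x_0 + t^{*} p_k,

where (a, b, c) solve the 3 \times 3 linear system built from the samples at t = 0, 1, 2.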

quadratic_optimize(func_value: Callable, func_grads: Callable, xk: torch.Tensor, eps: float = 0.0001, ls_eps: float = 0.0001, max_iter: int = 20, **kwargs)

Perform quadratic optimization with conjugate gradient method.

PARAMETER DESCRIPTION
func_value

Function to compute the value of the objective function

TYPE: Callable

func_grads

Function to compute gradients of the objective function

TYPE: Callable

xk

Initial point

TYPE: Tensor

eps

Convergence threshold, by default 1e-4

TYPE: float DEFAULT: 0.0001

ls_eps

Line search threshold, by default 1e-4

TYPE: float DEFAULT: 0.0001

max_iter

Maximum number of iterations, by default 20

TYPE: int DEFAULT: 20

**kwargs

Additional keyword arguments passed to func_value and func_grads

DEFAULT: {}

RETURNS DESCRIPTION
tuple

Tuple containing (xk, fk, gk, converge_iter) where:

- xk: optimal point
- fk: function value at the optimal point
- gk: gradient at the optimal point
- converge_iter: iteration at which convergence was achieved (-1 if max_iter was reached without converging)

Source code in torch_admp/optimizer.py
@torch.jit.unused
def quadratic_optimize(
    func_value: Callable,
    func_grads: Callable,
    xk: torch.Tensor,
    eps: float = 1e-4,
    ls_eps: float = 1e-4,
    max_iter: int = 20,
    **kwargs,
):
    """
    Perform quadratic optimization with conjugate gradient method.

    Parameters
    ----------
    func_value : Callable
        Function to compute the value of the objective function
    func_grads : Callable
        Function to compute gradients of the objective function
    xk : torch.Tensor
        Initial point
    eps : float, optional
        Convergence threshold, by default 1e-4
    ls_eps : float, optional
        Line search threshold, by default 1e-4
    max_iter : int, optional
        Maximum number of iterations, by default 20
    **kwargs
        Additional keyword arguments passed to func_value and func_grads

    Returns
    -------
    tuple
        Tuple containing (xk, fk, gk, converge_iter) where:
        - xk: optimal point
        - fk: function value at optimal point
        - gk: gradient at optimal point
        - converge_iter: iteration at which convergence was achieved
    """
    converge_iter: int = -1

    if xk.grad is not None:
        xk.grad.detach_()
        xk.grad.zero_()

    fk = func_value(xk, **kwargs)
    gk = func_grads(fk, xk)
    pk = -gk / torch.norm(gk)
    for ii in range(max_iter):
        fk.detach_()
        pk.detach_()
        # Selecting the step length
        x_new = line_search(func_value, func_grads, xk, ls_eps, fk, gk, pk, **kwargs)
        x_new.detach_()
        x_new.requires_grad = True
        if x_new.grad is not None:
            x_new.grad.detach_()
            x_new.grad.zero_()
        fk_new = func_value(x_new, **kwargs)
        gk_new = func_grads(fk_new, x_new)

        xk = x_new
        fk = fk_new

        norm_grad = torch.norm(gk_new) / xk.shape[0]
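        # converged once the gradient norm divided by the number of variables falls below eps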
        if norm_grad < eps:
            gk = gk_new
            converge_iter = ii
            break
        else:
            pk = update_pr(gk, pk, gk_new)
            gk = gk_new

    return xk, fk, gk, converge_iter
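
Restated in conventional nonlinear conjugate-gradient notation (a summary of the loop above; the coefficient named chi in the update_* functions below plays the role usually written \beta_k), each iteration performs

x_{k+1} = x_k + t_k^{*} p_k, \qquad p_{k+1} = -g_{k+1} + \chi_k p_k, \qquad \text{stop when } \frac{\lVert g_{k+1} \rVert}{n} < \varepsilon,

where t_k^{*} comes from the quadratic line search above, n is the number of optimized variables, and \chi_k is computed by the Polak-Ribière rule (update_pr).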

update_dy(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Dai-Yuan Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_dy(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Dai-Yuan Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    old_gk = gk
    gk = gk_new
    chi = torch.linalg.norm(gk) ** 2 / pk.dot(gk - old_gk)
    # Updated descent direction
    pk = -gk + chi * pk
    return pk
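
In the same notation, the Dai-Yuan coefficient computed above is

\chi_k^{\mathrm{DY}} = \frac{\lVert g_{k+1} \rVert^{2}}{p_k^{\top} (g_{k+1} - g_k)}, \qquad p_{k+1} = -g_{k+1} + \chi_k^{\mathrm{DY}} p_k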

update_fr(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Fletcher-Reeves Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_fr(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Fletcher-Reeves Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    old_gk = gk
    gk = gk_new
    # Line (16) of the Fletcher-Reeves algorithm
    chi = torch.linalg.norm(gk) ** 2 / torch.linalg.norm(old_gk) ** 2
    # Updated descent direction
    pk = -gk + chi * pk
    return pk
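
In the same notation, the Fletcher-Reeves coefficient computed above is

\chi_k^{\mathrm{FR}} = \frac{\lVert g_{k+1} \rVert^{2}}{\lVert g_k \rVert^{2}}, \qquad p_{k+1} = -g_{k+1} + \chi_k^{\mathrm{FR}} p_k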

update_hs(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Hestenes-Stiefel Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_hs(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Hestenes-Stiefel Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    old_gk = gk
    gk = gk_new
    chi = gk.dot(gk - old_gk) / pk.dot(gk - old_gk)
    # Updated descent direction
    pk = -gk + chi * pk
    return pk
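
In the same notation, the Hestenes-Stiefel coefficient computed above is

\chi_k^{\mathrm{HS}} = \frac{g_{k+1}^{\top} (g_{k+1} - g_k)}{p_k^{\top} (g_{k+1} - g_k)}, \qquad p_{k+1} = -g_{k+1} + \chi_k^{\mathrm{HS}} p_k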

update_hz(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Hager-Zhang Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_hz(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Hager-Zhang Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    old_gk = gk
    gk = gk_new
    delta_gk = gk - old_gk
    m = delta_gk - 2 * pk * torch.linalg.norm(delta_gk) ** 2 / pk.dot(delta_gk)
    n = gk / pk.dot(delta_gk)
    chi = m.dot(n)
    # Updated descent direction
    pk = -gk + chi * pk
    return pk
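
In the same notation, with y_k = g_{k+1} - g_k, the Hager-Zhang coefficient computed above is

\chi_k^{\mathrm{HZ}} = \left( y_k - 2 p_k \frac{\lVert y_k \rVert^{2}}{p_k^{\top} y_k} \right)^{\top} \frac{g_{k+1}}{p_k^{\top} y_k}, \qquad p_{k+1} = -g_{k+1} + \chi_k^{\mathrm{HZ}} p_k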

update_pr(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Polak-Ribiere Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_pr(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Polak-Ribiere Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    old_gk = gk
    gk = gk_new
    # Line (16) of the Polak-Ribiere Algorithm
    chi = (gk - old_gk).dot(gk) / torch.linalg.norm(old_gk) ** 2
    chi = torch.where(chi > 0, chi, torch.zeros_like(chi))
    # Updated descent direction
    pk = -gk + chi * pk
    return pk
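
In the same notation, the coefficient computed above is the nonnegative (PR+) variant of the Polak-Ribière rule, truncated at zero by the torch.where call:

\chi_k^{\mathrm{PR+}} = \max\!\left( 0, \; \frac{(g_{k+1} - g_k)^{\top} g_{k+1}}{\lVert g_k \rVert^{2}} \right), \qquad p_{k+1} = -g_{k+1} + \chi_k^{\mathrm{PR+}} p_k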

update_sd(gk: torch.Tensor, pk: torch.Tensor, gk_new: torch.Tensor) -> torch.Tensor

Update search direction using Steepest Descent Algorithm.

PARAMETER DESCRIPTION
gk

Current gradient

TYPE: Tensor

pk

Current search direction

TYPE: Tensor

gk_new

New gradient

TYPE: Tensor

RETURNS DESCRIPTION
Tensor

Updated search direction

Source code in torch_admp/optimizer.py
def update_sd(
    gk: torch.Tensor,
    pk: torch.Tensor,
    gk_new: torch.Tensor,
) -> torch.Tensor:
    """
    Update search direction using Steepest Descent Algorithm.

    Parameters
    ----------
    gk : torch.Tensor
        Current gradient
    pk : torch.Tensor
        Current search direction
    gk_new : torch.Tensor
        New gradient

    Returns
    -------
    torch.Tensor
        Updated search direction
    """
    gk = gk_new
    # Selection of the direction of the steepest descent
    pk = -gk / torch.linalg.norm(gk)
    return pk
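
In the same notation, the update above simply resets the direction to the normalized steepest-descent direction,

p_{k+1} = -\frac{g_{k+1}}{\lVert g_{k+1} \rVert}

The gk and pk arguments are accepted but unused, presumably so that all update_* functions share the same signature.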