# Source code for the conjugate-gradient and CG-Steihaug solvers.

# SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause

from datetime import datetime
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Tuple, Union

import jax
from jax import numpy as jnp

from .logger import logger
from .misc import doc_from
from .tree_math import assert_arithmetics, result_type
from .tree_math import norm as jft_norm
from .tree_math import size, vdot, where, zeros_like

# Signature of a Hessian-vector product: a callable that accepts a proposal
# vector and returns the Hessian applied to it (used as `hessp_at_xk` by the
# CG-Steihaug subproblem solver).
HessVP = Callable[[jnp.ndarray], jnp.ndarray]

# Every `N_RESET`-th CG iteration the residual is recomputed from scratch as
# `mat(pos) - j` instead of being updated recursively (`r - alpha * q`);
# presumably this bounds the drift accumulated by the recursive update.
N_RESET = 20

class CGResults(NamedTuple):
    """Outcome of a conjugate-gradient solve.

    ``info == 0`` indicates convergence and is what ``success`` reports; any
    other value means the solver stopped without converging (see the
    individual solvers for the exact codes).
    """
    x: jnp.ndarray  # final iterate / solution estimate
    nit: Union[int, jnp.ndarray]  # number of iterations performed
    nfev: Union[int, jnp.ndarray]  # number of matrix-evaluations
    info: Union[int, jnp.ndarray]  # solver status code
    success: Union[bool, jnp.ndarray]  # `info == 0`
def cg(mat, j, x0=None, *args, **kwargs) -> Tuple[Any, Union[int, jnp.ndarray]]:
    """Solve `mat(x) = j` using Conjugate Gradient.

    `mat` must be callable and represent a hermitian, positive definite
    matrix.

    Parameters
    ----------
    mat : callable
        Linear operator applied to iterates of the same structure as `j`.
    j : pytree
        Right-hand side of the linear system.
    x0 : pytree, optional
        Initial guess. Zeros are used if omitted.
    *args, **kwargs
        Forwarded to the underlying solver (e.g. `absdelta`, `resnorm`,
        `norm_ord`, `tol`, `atol`, `miniter`, `maxiter`, `name`).

    Returns
    -------
    x : pytree
        Solution estimate.
    info : int or jnp.ndarray
        Solver status code; ``0`` signals convergence.

    Notes
    -----
    If set, the parameters `absdelta` and `resnorm` always take precedence over
    `tol` and `atol`.
    """
    assert_arithmetics(j)
    if x0 is not None:
        assert_arithmetics(x0)
    cg_res = _cg(mat, j, x0, *args, **kwargs)
    # Return the (solution, status) pair promised by the annotation.
    return cg_res.x, cg_res.info
@doc_from(cg)
def static_cg(mat, j, x0=None, *args, **kwargs):
    # Same contract as `cg`, but solved inside a `jax.lax.while_loop` via
    # `_static_cg`, so it can be traced/jitted. The public docstring is taken
    # from `cg` through `doc_from`.
    assert_arithmetics(j)
    if x0 is not None:
        assert_arithmetics(x0)
    cg_res = _static_cg(mat, j, x0, *args, **kwargs)
    # Return the (solution, status) pair, mirroring `cg`.
    return cg_res.x, cg_res.info
def _cg_pretty_print_it(
    name,
    i,
    *,
    energy,
    energy_diff,
    absdelta=None,
    norm=None,
    resnorm=None,
    maxiter=None
):
    """Log a one-line progress report for CG iteration `i` under `name`.

    The optional `absdelta`, `norm` and `resnorm` entries are only appended
    when they are set.
    """
    if maxiter is not None and i == maxiter:
        # Flag the report emitted at the iteration limit.
        i_str = "✖" * len(str(i)) + f" ({i})"
    else:
        i_str = str(i)
    msg = f"{name}: Iteration {i_str} ⛰:{energy:+.4e} Δ⛰:{energy_diff:.4e}"
    msg += f" ➽:{absdelta:.4e}" if absdelta is not None else ""
    if norm is not None and resnorm is not None:
        msg += f" |∇|:{norm:.4e} ➽:{resnorm:.4e}"
    logger.info(msg)


# Taken from nifty
def _cg(
    mat,
    j,
    x0=None,
    *,
    absdelta=None,
    resnorm=None,
    norm_ord=None,
    tol=1e-5,  # taken from SciPy
    atol=0.,
    miniter=None,
    maxiter=None,
    name=None,
    time_threshold=None,
    _raise_nonposdef=True
) -> CGResults:
    """Solve ``mat(x) = j`` with an eagerly executed (Python loop) CG.

    Parameters
    ----------
    mat : callable
        Linear operator; expected to be hermitian positive definite.
    j : pytree
        Right-hand side.
    x0 : pytree, optional
        Initial guess; zeros if omitted.
    absdelta : float, optional
        Stop once the energy decrease of an iteration drops below this value
        (only after `miniter` iterations).
    resnorm : float, optional
        Stop once the residual norm drops below this value (only after
        `miniter` iterations). If neither `absdelta` nor `resnorm` is given,
        ``max(tol * ||j||, atol)`` is used.
    norm_ord : optional
        Order of the norm used for `resnorm`; defaults to 2.
    tol, atol : float
        Relative and absolute tolerance of the fallback `resnorm`.
    miniter, maxiter : int, optional
        Iteration bounds; defaults derived from ``size(j)``.
    name : str, optional
        If set, progress is logged under this name.
    time_threshold : datetime.datetime, optional
        Stop once the wall-clock time exceeds this point in time.
    _raise_nonposdef : bool
        If True, raise ``ValueError`` on zero/negative curvature or increasing
        energy; otherwise stop the iteration.

    Returns
    -------
    CGResults
        ``info == 0`` signals convergence; a positive ``info`` is the
        iteration at which the solver stopped without converging.
    """
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimzer
    miniter = min(
        (6, maxiter if maxiter is not None else maxiter_fallback)
    ) if miniter is None else miniter
    maxiter = max(
        (min((200, maxiter_fallback)), miniter)
    ) if maxiter is None else maxiter

    if absdelta is None and resnorm is None:  # fallback convergence criterion
        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord), atol)

    common_dtp = result_type(j)
    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimzer
    tiny = 6. * jnp.finfo(common_dtp).tiny

    if x0 is None:
        pos = zeros_like(j)
        r = -j
        d = r
        # energy = .5xT M x - xT j
        energy = 0.
        nfev = 0
    else:
        pos = x0
        r = mat(pos) - j
        d = r
        energy = float(vdot((r - j) / 2, pos))
        nfev = 1
    previous_gamma = float(vdot(r, r))

    info = -1  # sentinel: loop has not terminated via `break`
    i = 0
    energy_diff = jnp.inf
    norm = None
    pp = partial(
        _cg_pretty_print_it,
        name,
        absdelta=absdelta,
        resnorm=resnorm,
        maxiter=maxiter
    )
    if name is not None:
        if resnorm is not None:
            norm = jft_norm(r, ord=norm_ord)
        else:
            norm = None
        pp(i, energy=energy, energy_diff=energy_diff, norm=norm)
    if previous_gamma == 0:
        info = 0
        return CGResults(x=pos, info=info, nit=0, nfev=nfev, success=True)

    for i in range(1, maxiter + 1):
        q = mat(d)
        nfev += 1
        curv = float(vdot(d, q))
        if curv == 0.:
            if _raise_nonposdef:
                nm = "CG" if name is None else name
                raise ValueError(f"{nm}: zero curvature")
            info = 0
            break
        elif curv < 0.:
            if _raise_nonposdef:
                nm = "CG" if name is None else name
                raise ValueError(f"{nm}: negative curvature")
            if i > 1:
                info = 0
                break
            else:
                # Negative curvature in the very first step: fall back to a
                # scaled step along the right-hand side.
                pos = previous_gamma / (-curv) * j
                info = 0
                break
        alpha = previous_gamma / curv
        pos = pos - alpha * d
        if i % N_RESET == 0:
            # Periodically recompute the residual from scratch.
            r = mat(pos) - j
            nfev += 1
        else:
            r = r - q * alpha
        gamma = float(vdot(r, r))
        if time_threshold is not None and datetime.now() > time_threshold:
            info = i
            break
        if gamma >= 0. and gamma <= tiny:
            nm = "CG" if name is None else name
            logger.warning(f"{nm}: gamma=0, converged!")
            info = 0
            break
        if resnorm is not None:
            norm = float(jft_norm(r, ord=norm_ord))
            if norm < resnorm and i >= miniter:
                info = 0
                break
        else:
            norm = None
        new_energy = float(vdot((r - j) / 2, pos))
        energy_diff = energy - new_energy
        # Tolerate energy increases within floating-point noise.
        neg_energy_eps = -eps * jnp.abs(new_energy)
        if energy_diff < neg_energy_eps:
            nm = "CG" if name is None else name
            if _raise_nonposdef:
                raise ValueError(f"{nm}: WARNING: energy increased")
            logger.error(f"{nm}: WARNING: energy increased")
            info = i
            break
        if absdelta is not None and energy_diff < absdelta and i >= miniter:
            info = 0
            break
        energy = new_energy
        d = d * max(0, gamma / previous_gamma) + r
        previous_gamma = gamma
        if name is not None:
            pp(i, energy=energy, energy_diff=energy_diff, norm=norm)
    if name is not None and info != -1:
        # only print if loop was terminated via `break` otherwise everything is
        # already printed
        pp(i, energy=energy, energy_diff=energy_diff, norm=norm)

    # Exhausting `maxiter` without a `break` reports the iteration count.
    info = i if info == -1 else info
    return CGResults(x=pos, info=info, nit=i, nfev=nfev, success=info == 0)


def _static_cg(
    mat,
    j,
    x0=None,
    *,
    absdelta=None,
    resnorm=None,
    norm_ord=None,
    tol=1e-5,  # taken from SciPy
    atol=0.,
    miniter=None,
    maxiter=None,
    name=None,
    _raise_nonposdef=False,  # TODO
    **kwargs
) -> CGResults:
    """Solve ``mat(x) = j`` with CG inside a ``jax.lax.while_loop``.

    Takes the same convergence parameters as `_cg`. `_raise_nonposdef` and
    any extra `**kwargs` (e.g. `time_threshold`) are accepted but currently
    ignored.

    Returns
    -------
    CGResults
        ``info == 0`` signals convergence, ``info == -1`` a failure (zero or
        negative curvature, increasing energy) and a positive ``info`` the
        iteration at which `maxiter` was reached.
    """
    # NOTE(review): `jax.experimental.host_callback` is deprecated in newer
    # JAX releases; `jax.debug.callback` is the suggested replacement.
    from jax.experimental.host_callback import call
    from jax.lax import cond, while_loop

    # While in general it is strongly discouraged to use `jit` inside a
    # function, we do it here to avoid the overhead of jitting through `mat`
    # multiple times within the CG loop. This is safe because the call to
    # `while_loop` in CG implies a JIT anyways.
    mat = jax.jit(mat)

    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * size(j)  # taken from SciPy's NewtonCG minimzer
    miniter = jnp.minimum(
        6, maxiter if maxiter is not None else maxiter_fallback
    ) if miniter is None else miniter
    maxiter = jnp.maximum(
        jnp.minimum(200, maxiter_fallback), miniter
    ) if maxiter is None else maxiter

    if absdelta is None and resnorm is None:  # fallback convergence criterion
        resnorm = jnp.maximum(tol * jft_norm(j, ord=norm_ord), atol)

    common_dtp = result_type(j)
    eps = 6. * jnp.finfo(common_dtp).eps  # taken from SciPy's NewtonCG minimzer
    tiny = 6. * jnp.finfo(common_dtp).tiny

    def pp(arg):
        _cg_pretty_print_it(name, **arg)

    def continue_condition(v):
        # `info == -2` means "still running"; anything else terminates.
        return v["info"] < -1

    def cg_single_step(v):
        info = v["info"]
        pos, r, d, i = v["pos"], v["r"], v["d"], v["iteration"]
        previous_gamma, previous_energy = v["gamma"], v["energy"]

        i += 1

        q = mat(d)
        curv = vdot(d, q)
        # ValueError("zero curvature in conjugate gradient")
        info = jnp.where(curv == 0., -1, info)
        alpha = previous_gamma / curv
        # ValueError("implausible gradient scaling `alpha < 0`")
        info = jnp.where(alpha < 0., -1, info)
        pos = pos - alpha * d
        r = cond(
            i % N_RESET == 0, lambda x: mat(x["pos"]) - x["j"],
            lambda x: x["r"] - x["q"] * x["alpha"], {
                "pos": pos,
                "j": j,
                "r": r,
                "q": q,
                "alpha": alpha
            }
        )
        gamma = vdot(r, r)

        info = jnp.where(
            (gamma >= 0.) & (gamma <= tiny) & (info != -1), 0, info
        )
        if resnorm is not None:
            norm = jft_norm(r, ord=norm_ord)
            info = jnp.where(
                (norm < resnorm) & (i >= miniter) & (info != -1), 0, info
            )
        else:
            norm = None
        energy = vdot((r - j) / 2, pos)
        energy_diff = previous_energy - energy
        neg_energy_eps = -eps * jnp.abs(energy)
        # print(f"energy increased", file=sys.stderr)
        info = jnp.where(energy_diff < neg_energy_eps, -1, info)
        if absdelta is not None:
            info = jnp.where(
                (energy_diff < absdelta) & (i >= miniter) & (info != -1), 0,
                info
            )
        # Only flag the iteration limit while still running, so that a
        # convergence (0) or failure (-1) on the last iteration is preserved.
        info = jnp.where((i >= maxiter) & (info < -1), i, info)

        d = d * jnp.maximum(0, gamma / previous_gamma) + r

        if name is not None:
            printable_state = {
                "i": i,
                "energy": energy,
                "energy_diff": energy_diff,
                "absdelta": absdelta,
                "norm": norm,
                "resnorm": resnorm,
                "maxiter": maxiter
            }
            call(pp, printable_state, result_shape=None)

        ret = {
            "info": info,
            "pos": pos,
            "r": r,
            "d": d,
            "iteration": i,
            "gamma": gamma,
            "energy": energy
        }
        return ret

    if x0 is None:
        pos = zeros_like(j)
        r = -j
        d = r
        nfev = 0
    else:
        pos = x0
        r = mat(pos) - j
        d = r
        nfev = 1
    # energy = .5xT M x - xT j
    energy = jnp.array(0.) if x0 is None else vdot((r - j) / 2, pos)

    gamma = vdot(r, r)
    val = {
        "info": jnp.array(-2, dtype=int),
        "pos": pos,
        "r": r,
        "d": d,
        "iteration": jnp.array(0),
        "gamma": gamma,
        "energy": energy
    }
    # Finish early if already converged in the initial iteration
    val["info"] = jnp.where(gamma == 0., 0, val["info"])

    if name is not None:
        if resnorm is not None:
            norm = jft_norm(r, ord=norm_ord)
        else:
            norm = None
        printable_state = {
            "i": 0,
            "energy": energy,
            "energy_diff": jnp.inf,
            "absdelta": absdelta,
            "norm": norm,
            "resnorm": resnorm,
            "maxiter": maxiter
        }
        call(pp, printable_state, result_shape=None)

    val = while_loop(continue_condition, cg_single_step, val)

    i = val["iteration"]
    info = val["info"]
    # One `mat` call per iteration plus one per residual reset.
    nfev += i + i // N_RESET
    return CGResults(
        x=val["pos"], info=info, nit=i, nfev=nfev, success=info == 0
    )


# The following is code adapted from Nicholas Mancuso to work with pytrees
class _QuadSubproblemResult(NamedTuple):
    """Result of solving the quadratic trust-region subproblem."""
    step: jnp.ndarray  # proposed step
    hits_boundary: Union[bool, jnp.ndarray]  # step lies on the trust radius
    pred_f: Union[float, jnp.ndarray]  # model value at `step`
    nit: Union[int, jnp.ndarray]
    nfev: Union[int, jnp.ndarray]
    njev: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]  # number of Hessian-vector products
    success: Union[bool, jnp.ndarray]


class _CGSteihaugState(NamedTuple):
    """Loop carry of the CG-Steihaug `while_loop`."""
    z: jnp.ndarray  # current iterate
    r: jnp.ndarray  # current residual
    d: jnp.ndarray  # current search direction
    step: jnp.ndarray  # accepted step (set once `done`)
    energy: Union[None, float, jnp.ndarray]
    hits_boundary: Union[bool, jnp.ndarray]
    done: Union[bool, jnp.ndarray]
    nit: Union[int, jnp.ndarray]
    nhev: Union[int, jnp.ndarray]
def second_order_approx(
    p: jnp.ndarray,
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk: HessVP,
) -> Union[float, jnp.ndarray]:
    """Evaluate the second-order Taylor model at the step `p`.

    Computes ``cur_val + <g, p> + 0.5 * <p, H p>`` where ``H p`` is obtained
    from the Hessian-vector product `hessp_at_xk`.

    Parameters
    ----------
    p : pytree
        Step at which to evaluate the model.
    cur_val : float or jnp.ndarray
        Objective value at the expansion point.
    g : pytree
        Gradient at the expansion point.
    hessp_at_xk : callable
        Hessian-vector product at the expansion point.

    Returns
    -------
    float or jnp.ndarray
        Model value; costs one call to `hessp_at_xk`.
    """
    return cur_val + vdot(g, p) + 0.5 * vdot(p, hessp_at_xk(p))
def get_boundaries_intersections(
    z: jnp.ndarray, d: jnp.ndarray, trust_radius: Union[float, jnp.ndarray]
):  # Adapted from SciPy
    """Solve the scalar quadratic equation ||z + t d|| == trust_radius.

    This is like a line-sphere intersection.
    Return the two values of t, sorted from low to high.
    """
    a = vdot(d, d)
    b = 2 * vdot(z, d)
    c = vdot(z, z) - trust_radius**2
    sqrt_discriminant = jnp.sqrt(b * b - 4 * a * c)

    # The following calculation is mathematically
    # equivalent to:
    # ta = (-b - sqrt_discriminant) / (2*a)
    # tb = (-b + sqrt_discriminant) / (2*a)
    # but produce smaller round off errors.
    # Look at Matrix Computation p.97
    # for a better justification.
    aux = b + jnp.copysign(sqrt_discriminant, b)
    ta = -aux / (2 * a)
    tb = -2 * c / aux

    # Order the two roots so that the smaller one comes first.
    ra, rb = where(ta < tb, (ta, tb), (tb, ta))
    return (ra, rb)
def _cg_steihaug_subproblem(
    cur_val: Union[float, jnp.ndarray],
    g: jnp.ndarray,
    hessp_at_xk: HessVP,
    *,
    trust_radius: Union[float, jnp.ndarray],
    tr_norm_ord: Union[None, int, float, jnp.ndarray] = None,
    resnorm: Optional[float],
    absdelta: Optional[float] = None,
    norm_ord: Union[None, int, float, jnp.ndarray] = None,
    miniter: Union[None, int] = None,
    maxiter: Union[None, int] = None,
    name=None
) -> _QuadSubproblemResult:
    """
    Solve the subproblem using a conjugate gradient method.

    Parameters
    ----------
    cur_val : Union[float, jnp.ndarray]
      Objective value evaluated at the current state.
    g : jnp.ndarray
      Gradient value evaluated at the current state.
    hessp_at_xk: Callable
      Function that accepts a proposal vector and computes the result of a
      Hessian-vector product.
    trust_radius : float
      Upper bound on how large a step proposal can be.
    tr_norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
      Order of the norm for computing the length of the next step.
    resnorm : float or None
      Accept the current iterate once the residual norm drops below this
      value. If None, the residual norm is not used as a stopping criterion.
    absdelta : float, optional
      Accept the current iterate once the energy decrease is non-negative
      (up to round-off) and below this value, after `miniter` iterations.
    norm_ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional
      Order of the norm for testing convergence.
    miniter, maxiter : int, optional
      Iteration bounds; defaults derived from ``size(g)``.
    name : str, optional
      If set, progress is logged under this name.

    Returns
    -------
    result : _QuadSubproblemResult
      Contains the step proposal, whether it is at radius boundary, and
      meta-data regarding function calls and successful convergence.

    Notes
    -----
    This is algorithm (7.2) of Nocedal and Wright 2nd edition.
    Only the function that computes the Hessian-vector product is required.
    The Hessian itself is not required, and the Hessian does
    not need to be positive semidefinite.
    """
    # NOTE(review): `jax.experimental.host_callback` is deprecated in newer
    # JAX releases; `jax.debug.callback` is the suggested replacement.
    from jax.experimental.host_callback import call
    from jax.lax import switch, while_loop

    tr_norm_ord = jnp.inf if tr_norm_ord is None else tr_norm_ord  # taken from JAX
    norm_ord = 2 if norm_ord is None else norm_ord  # TODO: change to 1
    maxiter_fallback = 20 * size(g)  # taken from SciPy's NewtonCG minimzer
    miniter = jnp.minimum(
        6, maxiter if maxiter is not None else maxiter_fallback
    ) if miniter is None else miniter
    maxiter = jnp.maximum(
        jnp.minimum(200, maxiter_fallback), miniter
    ) if maxiter is None else maxiter

    common_dtp = result_type(g)
    # Inspired by SciPy's NewtonCG minimzer
    eps = 6. * jnp.finfo(common_dtp).eps

    # second-order Taylor series approximation at the current values, gradient,
    # and hessian
    soa = partial(
        second_order_approx, cur_val=cur_val, g=g, hessp_at_xk=hessp_at_xk
    )

    def pp(arg):
        msg = (
            "{name}: |∇|:{r_norm:.6e}" +
            (" ➽:{resnorm:.6e}" if arg["resnorm"] is not None else "") +
            " ↗:{tr:.6e} ☞:{case:1d} #∇²:{nhev:02d}"
            "\n{name}: Iteration {i} ⛰:{energy:+.6e} Δ⛰:{energy_diff:.6e}" +
            (" ➽:{absdelta:.6e}" if arg["absdelta"] is not None else "") + (
                "\n{name}: Iteration Limit Reached"
                if arg["i"] == arg["maxiter"] else ""
            )
        )
        logger.info(msg.format(name=name, **arg))

    # helpers for internal switches in the main CGSteihaug logic
    def noop(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # None of the termination conditions holds: keep iterating.
        iterp, z_next = param
        return iterp

    def step1(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Non-positive curvature: stop at the boundary point (of the two
        # intersections) with the lower model value.
        iterp, z_next = param
        z, d, nhev = iterp.z, iterp.d, iterp.nhev

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        pa = z + ta * d
        pb = z + tb * d
        p_boundary = where(soa(pa) < soa(pb), pa, pb)
        # Each `soa` evaluation costs one Hessian-vector product.
        return iterp._replace(
            step=p_boundary, nhev=nhev + 2, hits_boundary=True, done=True
        )

    def step2(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # The next iterate leaves the trust region: stop at the forward
        # intersection of the current direction with the boundary.
        iterp, z_next = param
        z, d = iterp.z, iterp.d

        ta, tb = get_boundaries_intersections(z, d, trust_radius)
        p_boundary = z + tb * d
        return iterp._replace(step=p_boundary, hits_boundary=True, done=True)

    def step3(
        param: Tuple[_CGSteihaugState, Union[float, jnp.ndarray]]
    ) -> _CGSteihaugState:
        # Convergence criterion met inside the trust region: accept `z_next`.
        iterp, z_next = param
        return iterp._replace(step=z_next, hits_boundary=False, done=True)

    # initialize the step
    p_origin = zeros_like(g)

    # init the state for the first iteration
    z = p_origin
    r = g
    d = -r
    energy = 0.
    init_param = _CGSteihaugState(
        z=z,
        r=r,
        d=d,
        step=p_origin,
        energy=energy,
        hits_boundary=False,
        done=maxiter == 0,
        nit=0,
        nhev=0
    )

    # Search for the min of the approximation of the objective function.
    def body_f(iterp: _CGSteihaugState) -> _CGSteihaugState:
        z, r, d = iterp.z, iterp.r, iterp.d
        energy, nit = iterp.energy, iterp.nit

        nit += 1

        Bd = hessp_at_xk(d)
        dBd = vdot(d, Bd)

        r_squared = vdot(r, r)
        alpha = r_squared / dBd
        z_next = z + alpha * d

        r_next = r + alpha * Bd
        r_next_squared = vdot(r_next, r_next)

        beta_next = r_next_squared / r_squared
        d_next = -r_next + beta_next * d

        accept_z_next = nit >= maxiter
        if norm_ord == 2:
            r_next_norm = jnp.sqrt(r_next_squared)
        else:
            r_next_norm = jft_norm(r_next, ord=norm_ord)
        if resnorm is not None:
            accept_z_next |= r_next_norm < resnorm
        # Relative to a plain CG, `z_next` is negative
        energy_next = vdot((r_next + g) / 2, z_next)
        energy_diff = energy - energy_next
        if absdelta is not None:
            neg_energy_eps = -eps * jnp.abs(energy)
            accept_z_next |= (energy_diff >= neg_energy_eps
                             ) & (energy_diff < absdelta) & (nit >= miniter)

        # include a junk switch to catch the case where none should be executed
        # (`argmax` of an all-False array returns 0, i.e. `noop`)
        z_next_norm = jft_norm(z_next, ord=tr_norm_ord)
        index = jnp.argmax(
            jnp.array(
                [False, dBd <= 0, z_next_norm >= trust_radius, accept_z_next]
            )
        )
        iterp = switch(index, [noop, step1, step2, step3], (iterp, z_next))

        iterp = iterp._replace(
            z=z_next,
            r=r_next,
            d=d_next,
            energy=energy_next,
            nhev=iterp.nhev + 1,
            nit=nit
        )
        if name is not None:
            printable_state = {
                "i": nit,
                "energy": iterp.energy,
                "energy_diff": energy_diff,
                "absdelta": absdelta,
                "tr": trust_radius,
                "r_norm": r_next_norm,
                "resnorm": resnorm,
                "nhev": iterp.nhev,
                "case": index,
                "maxiter": maxiter
            }
            call(pp, printable_state, result_shape=None)
        return iterp

    def cond_f(iterp: _CGSteihaugState) -> bool:
        return jnp.logical_not(iterp.done)

    # perform inner optimization to solve the constrained
    # quadratic subproblem using cg
    result = while_loop(cond_f, body_f, init_param)

    pred_f = soa(result.step)
    result = _QuadSubproblemResult(
        step=result.step,
        hits_boundary=result.hits_boundary,
        pred_f=pred_f,
        nit=result.nit,
        nfev=0,
        njev=0,
        # +1 for the Hessian-vector product inside `soa(result.step)`.
        nhev=result.nhev + 1,
        success=True
    )
    return result