Example 2

Corresponds to Example 4.7 in [Ban2020_ParameterDependentMultilevel].

[1]:
import pymloc
import numpy as np
import jax.numpy as jnp
import jax
import warnings
warnings.filterwarnings('ignore')

Creating the object

First, we need to create a parameter-dependent optimal control problem, which consists of variables, an objective function, and a constraint.

Creating the variables object

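The variables for the different levels are defined as follows: the lower level has no variables, the optimal control problem has two states and one input on the time interval [0, 2], and the higher level holds the two parameters in R^2.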

[2]:
from pymloc.model.variables import InputStateVariables
from pymloc.model.variables import NullVariables
from pymloc.model.variables import ParameterContainer
from pymloc.model.variables.time_function import Time
from pymloc.model.domains import RNDomain


loc_vars = InputStateVariables(2, 1, time=Time(0., 2.))
hl_vars = ParameterContainer(2, domain=RNDomain(2))
variables2 = (hl_vars, loc_vars)

ll_vars = NullVariables()

Creating the control system

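The parameter-dependent control system is defined by the following coefficient functions of the parameter vector p and the time t. Only q = p[1] enters the system: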

[3]:
from pymloc.model.control_system.parameter_dae import LinearParameterControlSystem

@jax.jit
def e(p, t):
    q = p[1]
    return jnp.array([[1., 0.], [q, 0.]])

@jax.jit
def a(p, t):
    q = p[1]
    return jnp.array([[-1., 0.], [-q, 1.]])

@jax.jit
def b(p, t):
    q = p[1]
    return jnp.array([[1.], [q]])

@jax.jit
def c(p, t):
    return jnp.identity(2)

@jax.jit
def d(p, t):
    return jnp.array([[0.]])

@jax.jit
def f(p, t):
    return jnp.array([0., 0.])

param_control = LinearParameterControlSystem(ll_vars, *variables2, e, a, b, c,
                                             d, f)
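The system can be written out explicitly. This assumes the coefficients enter in the usual descriptor form E ẋ = A x + B u + f, y = C x + D u. That form is our reading of the coefficient names and has not been checked against the pymloc documentation. With q = p[1] the system is shown below. Subtracting q times the first row from the second shows that the second row is an algebraic constraint that forces x_2 = 0 for every value of q:

```latex
\begin{aligned}
&\begin{bmatrix} 1 & 0 \\ q & 0 \end{bmatrix}
 \begin{bmatrix} \dot x_1 \\ \dot x_2 \end{bmatrix}
 =
 \begin{bmatrix} -1 & 0 \\ -q & 1 \end{bmatrix}
 \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}
 + \begin{bmatrix} 1 \\ q \end{bmatrix} u,
 \qquad y = x, \\[4pt]
&\underbrace{q\dot x_1 - q\dot x_1}_{=\,0}
 = (-q x_1 + x_2 + q u) - q(-x_1 + u) = x_2
 \quad\Longrightarrow\quad x_2 = 0 .
\end{aligned}
```

Consequently, a consistent initial value must satisfy x_2(t_0) = 0.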

Creating the constraint object

[4]:
from pymloc.model.optimization.parameter_optimal_control import ParameterLQRConstraint


def initial_value(p):
    # Initial state x(t0) = (2, 0); it does not depend on the parameters p
    return np.array([2., 0.])


pdoc_constraint = ParameterLQRConstraint(*variables2, param_control,
                                         initial_value)

Creating the objective function

The objective function is defined by the following weight functions. Only p[0] enters, through the state weight q:

[5]:
from pymloc.model.optimization.parameter_optimal_control import ParameterLQRObjective

@jax.jit
def q(p, t):
    return jnp.array([[p[0]**2. - 1., 0.], [0., 0.]])

@jax.jit
def s(p, t):
    return jnp.zeros((2, 1))

@jax.jit
def r(p, t):
    return jnp.array([[1.]])

@jax.jit
def m(p):
    return jnp.zeros((2, 2))

time = Time(0., 2.)
pdoc_objective = ParameterLQRObjective(*variables2, time, q, s, r, m)
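The objective can be written out explicitly. This assumes the usual quadratic cost, in which q, s and r are the weights Q, S and R, and m is the terminal weight M. That reading is ours, and pymloc's exact scaling (for example a factor 1/2) may differ. With θ = p[0] and the interval [t_0, t_f] = [0, 2], the objective reads:

```latex
J(p) = x(t_f)^\top M\, x(t_f)
     + \int_{t_0}^{t_f}
       \begin{bmatrix} x \\ u \end{bmatrix}^\top
       \begin{bmatrix} Q & S \\ S^\top & R \end{bmatrix}
       \begin{bmatrix} x \\ u \end{bmatrix} \mathrm{d}t
     = \int_{0}^{2} \bigl( (\theta^2 - 1)\, x_1^2 + u^2 \bigr)\, \mathrm{d}t .
```

At θ = 2 this gives Q = diag(3, 0), and the entry 3 reappears in the BVP matrix A printed below. For |θ| < 1 the state weight θ² − 1 is negative.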

Create the parameter dependent optimal control object

[6]:
from pymloc.model.optimization.parameter_optimal_control import ParameterDependentOptimalControl

pdoc_object = ParameterDependentOptimalControl(*variables2, pdoc_objective,
                                               pdoc_constraint)

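The necessary optimality conditions form a boundary value problem (BVP), which can be obtained as follows. Here its coefficient matrices E and A are evaluated at p = (2, 4):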

[7]:
parameters = np.array([2., 4.])
time = 3.  # all coefficient functions above are time-invariant, so any time value works

necessary_conditions = pdoc_object.get_bvp()
e_bvp = necessary_conditions.dynamical_system.e(parameters, time)
a_bvp = necessary_conditions.dynamical_system.a(parameters, time)
print("E =\n {},\nA =\n {}".format(e_bvp, a_bvp))
E =
 [[ 0.  0.  1.  0.  0.]
 [ 0.  0.  4.  0.  0.]
 [-1. -4.  0.  0.  0.]
 [-0. -0.  0.  0.  0.]
 [-0. -0.  0.  0.  0.]],
A =
 [[ 0.  0. -1.  0.  1.]
 [ 0.  0. -4.  1.  4.]
 [-1. -4.  3.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [ 1.  4.  0.  0.  1.]]
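The block layout of these matrices can be checked by hand. The following is a minimal NumPy sketch that rebuilds E and A from the coefficients at p = (2, 4) and compares them with the printed matrices. It assumes the BVP variables are ordered as (multipliers λ, states x, input u). That ordering is inferred from the printed output, not taken from the pymloc source.

```python
import numpy as np

# Coefficients from the cells above at p = (2, 4): q = p[1] = 4, theta = p[0] = 2
q = 4.0
E = np.array([[1., 0.], [q, 0.]])
A = np.array([[-1., 0.], [-q, 1.]])
B = np.array([[1.], [q]])
Q = np.array([[2.**2 - 1., 0.], [0., 0.]])
S = np.zeros((2, 1))
R = np.array([[1.]])

Z22, Z21 = np.zeros((2, 2)), np.zeros((2, 1))
Z12, Z11 = np.zeros((1, 2)), np.zeros((1, 1))

# Assumed block structure for the variable ordering (lambda, x, u)
E_bvp = np.block([[Z22, E, Z21],
                  [-E.T, Z22, Z21],
                  [Z12, Z12, Z11]])
A_bvp = np.block([[Z22, A, B],
                  [A.T, Q, S],
                  [B.T, S.T, R]])

# Matrices printed by the notebook cell above
E_printed = np.array([[0, 0, 1, 0, 0], [0, 0, 4, 0, 0], [-1, -4, 0, 0, 0],
                      [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], dtype=float)
A_printed = np.array([[0, 0, -1, 0, 1], [0, 0, -4, 1, 4], [-1, -4, 3, 0, 0],
                      [0, 1, 0, 0, 0], [1, 4, 0, 0, 1]], dtype=float)

print(np.array_equal(E_bvp, E_printed), np.array_equal(A_bvp, A_printed))
```

Read row by row, the BVP encodes three relations:
- the state equation E ẋ = A x + B u,
- the adjoint equation −Eᵀ dλ/dt = Aᵀλ + Q x + S u,
- the optimality condition 0 = Bᵀλ + Sᵀx + R u.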

Obtaining sensitivity values

Reference solution

An analytic reference solution for the sensitivities is given by the following function:

[8]:
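def refsol(theta, t0, tf, t, x01):
    """Analytic reference sensitivities at time t (theta = p[0], interval [t0, tf], x01 = first initial state)."""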
    refsol = np.array(
        [[(1 / ((-1 + theta + np.exp(2 * (-t0 + tf) * theta) *
                 (1 + theta))**2)) * np.exp(-(t + t0) * theta) * x01 *
          (np.exp(-2 * (t0 - 2 * tf) * theta) * (1 + theta)**2 *
           (1 + t + t0 *
            (-1 + theta) - t * theta) - np.exp(2 * (t - t0 + tf) * theta) *
           (1 + theta)**2 *
           (1 + 2 * tf + t * (-1 + theta) + t0 *
            (-1 + theta) - 2 * tf * theta) - np.exp(2 * t * theta) * 2 *
           (-1 + theta)**2 * (1 + t * (1 + theta) - t0 *
                              (1 + theta)) + np.exp(2 * tf * theta) *
           (-1 + theta)**2 * (1 - t * (1 + theta) - t0 * (1 + theta) + 2 * tf *
                              (1 + theta)))],
         [(1 / ((-1 + theta + np.exp(2 * (-t0 + tf) * theta) *
                 (1 + theta))**2)) * np.exp(-(t + t0) * theta) * x01 *
          (np.exp(2 * t * theta) * (t - t0) *
           (-1 + theta)**2 + np.exp(-2 * t0 * theta + 4 * tf * theta) *
           (-t + t0) * (1 + theta)**2 + np.exp(2 * tf * theta) *
           (-2 + t + t0 - 2 * tf -
            (t + t0 - 2 * tf) * theta**2) + np.exp(2 * (t - t0 + tf) * theta) *
           (2 - t - t0 + 2 * tf + (t + t0 - 2 * tf) * theta**2))]])
    return refsol
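The formula can be given a quick consistency check. The initial state x(t0) = (2, 0) does not depend on the parameters, so the sensitivity of x1 must vanish at t = t0. The sketch below assumes that the second component of `refsol` is the x1 sensitivity, as the construction of `ref` in the next cell suggests. The function is repeated verbatim so that the sketch runs on its own, and the chosen θ values are arbitrary.

```python
import numpy as np

# Copied verbatim from the cell above
def refsol(theta, t0, tf, t, x01):
    refsol = np.array(
        [[(1 / ((-1 + theta + np.exp(2 * (-t0 + tf) * theta) *
                 (1 + theta))**2)) * np.exp(-(t + t0) * theta) * x01 *
          (np.exp(-2 * (t0 - 2 * tf) * theta) * (1 + theta)**2 *
           (1 + t + t0 *
            (-1 + theta) - t * theta) - np.exp(2 * (t - t0 + tf) * theta) *
           (1 + theta)**2 *
           (1 + 2 * tf + t * (-1 + theta) + t0 *
            (-1 + theta) - 2 * tf * theta) - np.exp(2 * t * theta) * 2 *
           (-1 + theta)**2 * (1 + t * (1 + theta) - t0 *
                              (1 + theta)) + np.exp(2 * tf * theta) *
           (-1 + theta)**2 * (1 - t * (1 + theta) - t0 * (1 + theta) + 2 * tf *
                              (1 + theta)))],
         [(1 / ((-1 + theta + np.exp(2 * (-t0 + tf) * theta) *
                 (1 + theta))**2)) * np.exp(-(t + t0) * theta) * x01 *
          (np.exp(2 * t * theta) * (t - t0) *
           (-1 + theta)**2 + np.exp(-2 * t0 * theta + 4 * tf * theta) *
           (-t + t0) * (1 + theta)**2 + np.exp(2 * tf * theta) *
           (-2 + t + t0 - 2 * tf -
            (t + t0 - 2 * tf) * theta**2) + np.exp(2 * (t - t0 + tf) * theta) *
           (2 - t - t0 + 2 * tf + (t + t0 - 2 * tf) * theta**2))]])
    return refsol

# The initial state does not depend on theta, so the x1 sensitivity
# (assumed to be the second component) must be zero at t = t0 = 0
for theta in (0.5, 2., 3.):
    rsol_t0 = refsol(theta, 0., 2., 0., 2.)
    print(theta, rsol_t0[1, 0])
```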

Computed solution

Run the default sensitivity solver (adjoint computation) with absolute and relative tolerances of 1e-6 and compare the result with the reference solution.

The summands of the adjoint solution are displayed in the last log message ("All summands").

[9]:
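import logging
logger = logging.getLogger()
# Show the log output of the nested solvers up to nesting level 3
logger.handlers[0].filters[0].__class__.max_level = 3

sens = pdoc_object.get_sensitivities()
sens.init_solver(abs_tol=1e-6, rel_tol=1e-6)
# Sensitivities for the parameters p = (2, 1) at tau = 1
sol = sens.solve(parameters=np.array([2., 1.]), tau=1.)(1.)

# Reference values for theta = 2, t0 = 0, tf = 2, t = 1 and x01 = 2
rsol = refsol(2., 0., 2., 1., 2.)
# Arrange them in the variable ordering of the BVP (two multipliers, two states,
# one input); the sensitivity with respect to p[1] (second column) is zero
ref = np.block([[rsol[0]], [0], [rsol[1]], [0], [-rsol[0]]])
ref = np.block([[ref, np.zeros((5, 1))]])

print("Absolute error: ", np.linalg.norm(ref - sol))
np.allclose(ref, sol, rtol=1e-9, atol=1e-2)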
    Starting solver AdjointSensitivitiesSolver
    Current option values:
        abs_tol: 1e-06
        rel_tol: 1e-06
        max_iter: 10
    Compute sensitivity at tau = 1.0
            Starting solver MultipleShooting
            Current option values:
                abs_tol: 1.6666666666666665e-07
                rel_tol: 1.6666666666666665e-07
                max_iter: 10
            MultipleShooting solver initialized with

                        shooting_nodes: [0.  0.5 1.  1.5 2. ]

                        boundary_nodes: (0.0, 2.0)
                
            Computing solution in the interval (0.0, 0.5)
            Computing solution in the interval (0.5, 1.0)
            Computing solution in the interval (1.0, 1.5)
            Computing solution in the interval (1.5, 2.0)
            Computing inhomogeneous solution in the interval (0.0, 0.5)
            Computing inhomogeneous solution in the interval (0.5, 1.0)
            Computing inhomogeneous solution in the interval (1.0, 1.5)
            Computing inhomogeneous solution in the interval (1.5, 2.0)
            Computing inhomogeneous solution in the interval (0.0, 0.5)
            Computing inhomogeneous solution in the interval (0.5, 1.0)
            Computing inhomogeneous solution in the interval (1.0, 1.5)
            Computing inhomogeneous solution in the interval (1.5, 2.0)
    Assembling adjoint sensitivity boundary value problem...
    Solving adjoint boundary value problem...
            Starting solver MultipleShooting
            Current option values:
                abs_tol: 1.6666666666666665e-07
                rel_tol: 1.6666666666666665e-07
                max_iter: 10
            MultipleShooting solver initialized with

                        shooting_nodes: [0. 1. 2.]

                        boundary_nodes: (0.0, 1.0, 2.0)
                
            Computing solution in the interval (0.0, 1.0)
            Computing solution in the interval (1.0, 2.0)
            Computing inhomogeneous solution in the interval (0.0, 1.0)
            Computing inhomogeneous solution in the interval (1.0, 2.0)
            Computing inhomogeneous solution in the interval (0.0, 1.0)
            Computing inhomogeneous solution in the interval (1.0, 2.0)
            Computing inhomogeneous solution in the interval (0.0, 1.0)
            Computing inhomogeneous solution in the interval (1.0, 2.0)
    All summands:
        [[[ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]]

         [[ 2.86313431e-47  1.32841688e-01]
          [ 6.87427168e-47  7.37419503e-17]
          [-6.68522895e-48  3.02305565e-17]
          [ 2.71185104e-16  6.26624420e-17]
          [-6.22391196e-32 -1.32841688e-01]]

         [[ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]]

         [[ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  0.00000000e+00]]

         [[ 0.00000000e+00 -2.94967754e-17]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00 -3.02305503e-17]
          [ 0.00000000e+00  0.00000000e+00]
          [ 0.00000000e+00  2.94967754e-17]]

         [[ 1.00131499e-02 -1.32843861e-01]
          [ 0.00000000e+00  0.00000000e+00]
          [-2.74401376e-01  3.69198159e-07]
          [ 0.00000000e+00  0.00000000e+00]
          [-1.00131499e-02  1.32843861e-01]]]
Absolute error:  7.910407058878245e-06
[9]:
True
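To see how the error behaves as the tolerances are tightened, you can wrap the solve-and-compare step in a loop. The following is a minimal sketch built around a callable `solve_at_tol`. That callable is hypothetical: it is not part of pymloc, and it must return the computed sensitivity matrix for a given tolerance. The commented usage relies only on the `init_solver` and `solve` calls shown above. It also assumes `init_solver` may be called again to reset the tolerances, which has not been verified here.

```python
import numpy as np
from typing import Callable, Dict, Sequence


def tolerance_study(solve_at_tol: Callable[[float], np.ndarray],
                    ref: np.ndarray,
                    tols: Sequence[float]) -> Dict[float, float]:
    """Return the absolute error ||ref - sol|| for each tolerance.

    solve_at_tol is a hypothetical callable that returns the computed
    sensitivity matrix (same shape as ref) for abs_tol = rel_tol = tol.
    """
    errors = {}
    for tol in tols:
        sol = np.asarray(solve_at_tol(tol))
        errors[tol] = float(np.linalg.norm(ref - sol))
    return errors


# Possible use with the objects from this notebook (not executed here):
#
# def solve_at_tol(tol):
#     sens.init_solver(abs_tol=tol, rel_tol=tol)
#     return sens.solve(parameters=np.array([2., 1.]), tau=1.)(1.)
#
# errors = tolerance_study(solve_at_tol, ref, [1e-4, 1e-6, 1e-8])
```

Keeping the solver behind a plain callable lets you test the loop without pymloc. It also makes it easy to switch to another sensitivity solver.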