Example 301: PDE Optimization with SciMLSensitivity

This example shows how the package integrates with the SciML ecosystem. For more details on the problem, see the original formulation.
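Concretely, the script fits the two parameters of a linear reaction–diffusion equation. With the coefficients chosen in `pdefun` below (symmetry parameter $m = 0$, capacity $c = 1$, flux $f = a_1\,\partial u/\partial x$, source $s = 2 a_0 u$), the Skeel–Berzins form $c\,\frac{\partial u}{\partial t} = x^{-m}\,\frac{\partial}{\partial x}\left(x^m f\right) + s$ reduces to

$$\frac{\partial u}{\partial t} = a_1 \frac{\partial^2 u}{\partial x^2} + 2 a_0 u, \qquad x \in (0, 10),$$

with homogeneous Dirichlet boundary conditions $u(0, t) = u(10, t) = 0$ and the Gaussian initial condition $u(x, 0) = e^{-(x - 3)^2}$. The optimization recovers the true parameters $p = (a_0, a_1) = (1, 1)$ from a synthetic reference solution.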

module Example301_PDEOptimSciMLSensitivity
using SkeelBerzins, DifferentialEquations

using DelimitedFiles
using Optimization, OptimizationPolyalgorithms, OptimizationOptimJL, SciMLSensitivity

function f(p)

    # Spatial discretization
    m = 0               # symmetry parameter (Cartesian coordinates)
    Lx = 10.0           # domain length
    x = 0.0:0.01:Lx     # spatial grid
    dx = x[2] - x[1]

    # Time discretization
    dt = 0.40 * dx^2    # CFL-type step bound for the diffusion term
    t0, tMax = 0.0, 1000 * dt
    tspan = (t0, tMax)
    t = t0:dt:tMax      # output times

    a0, a1 = p

    # PDE coefficients in the Skeel-Berzins form c*du/dt = x^(-m)*d/dx(x^m*f) + s
    function pdefun(x, t, u, dudx)
        c = 1              # capacity term
        f = a1 * dudx      # diffusive flux
        s = 2.0 * a0 * u   # linear source (reaction) term

        c, f, s
    end

    # Initial condition: Gaussian bump centered at x = 3
    icfun(x) = exp(-(x - 3.0)^2)

    # Boundary conditions in the form p + q*f = 0; taking p = u and q = 0 at
    # both endpoints imposes homogeneous Dirichlet conditions u = 0
    function bcfun(xl, ul, xr, ur, t)
        pl = ul
        ql = 0
        pr = ur
        qr = 0

        pl, ql, pr, qr
    end

    # solver=:DiffEq assembles the spatial discretization as an ODE system for
    # DifferentialEquations.jl; nb_design_var declares how many design variables
    # the problem is differentiated against during the sensitivity analysis
    params_pdepe = SkeelBerzins.Params(; solver=:DiffEq, nb_design_var=length(p))

    # Semi-discretize in space, wrap the result as an ODEProblem and integrate
    # it with a Rosenbrock-W method using a sparse direct linear solver
    pb = pdepe(m, pdefun, icfun, bcfun, collect(x), tspan; params=params_pdepe)
    prob = DifferentialEquations.ODEProblem(pb)
    sol = DifferentialEquations.solve(prob, ROS34PW1a(; linsolve=SparspakFactorization()); dt=dt, saveat=t)

    sol
end

function main()
    p = [1.0, 1.0] # True solution parameters
    sol_exact = Array(f(p)) # Synthetic reference data generated with the true parameters

    # Build the prediction model
    ps = [0.1, 0.2]  # Initial guess for the model parameters
    predict(θ) = Array(f(θ))

    # Define the loss: sum of squared errors against the reference data. The
    # prediction is returned alongside the loss so the callback can record it.
    function loss(θ)
        pred = predict(θ)
        return sum(abs2.(pred .- sol_exact)), pred
    end

    LOSS = []     # loss accumulator
    PRED = []     # prediction accumulator
    PARS = []     # parameter accumulator

    callback = function (θ, l, pred) # callback function to observe training
        display(l)
        push!(PRED, pred)
        push!(LOSS, l)
        push!(PARS, θ)
        false # returning false keeps the optimization running
    end

    adtype = Optimization.AutoForwardDiff() # see https://docs.sciml.ai/Optimization/stable/API/optimization_function/#Automatic-Differentiation-Construction-Choice-Recommendations
    optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)

    optprob = Optimization.OptimizationProblem(optf, ps)
    # NewtonTrustRegion is Optim.jl's trust-region Newton method, made
    # available through OptimizationOptimJL
    res = Optimization.solve(optprob, NewtonTrustRegion(); allow_f_increases=true, callback=callback)

    return res.u
end

using Test

function runtests()
    testval = [1.0, 1.0]
    @test main() ≈ testval
end

end
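
To run the example, include the module and call its functions; a minimal sketch, assuming the packages listed above are installed (the file name here is hypothetical):

include("Example301_PDEOptimSciMLSensitivity.jl") # hypothetical path to this file
p_opt = Example301_PDEOptimSciMLSensitivity.main() # fitted parameters, ≈ [1.0, 1.0]
Example301_PDEOptimSciMLSensitivity.runtests()     # verifies the recovered parameters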

This page was generated using Literate.jl.