Skip to content

Commit

Permalink
Split code in modules; implement a type safe context
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Dec 30, 2023
1 parent 76ca952 commit 2fb17eb
Show file tree
Hide file tree
Showing 9 changed files with 870 additions and 405 deletions.
7 changes: 4 additions & 3 deletions examples/basic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use sundials::CVode;
use sundials::{context, CVode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut ode = CVode::adams(0., &[0.],
|_t, _u, du| *du = [1.])?;
let ctx = context!().unwrap();
let mut ode = CVode::adams(ctx, 0., &[0.],
|_t, _u, du| *du = [1.])?;
let mut u1 = [f64::NAN];
ode.solve(1., &mut u1);
assert_eq!(u1[0], 1.);
Expand Down
6 changes: 6 additions & 0 deletions src/arkode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! Arkode is a solver for stiff, nonstiff, mixed stiff-nonstiff, and
//! multirate ODE systems based on Runge-Kutta methods. Includes
//! support for IMEX methods.

#[allow(dead_code)]
pub struct ARKode {}
361 changes: 361 additions & 0 deletions src/cvode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,361 @@
//! `Cvode` is a solver for stiff and nonstiff ODE systems ẏ = f(t,y)
//! with root detection and projection on a constraint manifold.
//! Based on Adams and BDF methods.

use std::{
ffi::{c_int, c_void, c_long},
pin::Pin,
ptr,
};
use sundials_sys::*;
use super::{
Context,
Error,
vector::Vector,
Matrix,
LinSolver
};

// Implement the Drop trait only on the pointer to be able to move
// values out of the structure `CVode`.
#[derive(Debug)]
struct CVodeMem(*mut c_void);

impl Drop for CVodeMem {
fn drop(&mut self) { unsafe { CVodeFree(&mut self.0) } }
}

/// Solver for stiff and nonstiff initial value problems for ODE systems.
///
/// The generic parameters are as follows: `Ctx` is the type of the
/// context, `V` is the type of vectors, `F` the type of the function
/// being used as right-hand side of the ODE, and `G` is the type of
/// the functions (if any) of which we want to seek roots.
pub struct CVode<'a, Ctx, V, F, G>
where V: Vector {
// One must take ownership of the context because it can only be
// used in a single ODE solver.
ctx: Ctx,
cvode_mem: CVodeMem,
t0: f64,
// FIXME: after used in CVodeInit, the initial vector is copied to
// the internal structures. No need to keep it.
y0: V::NVectorRef<'a>,
rtol: f64,
atol: f64,
matrix: Matrix,
linsolver: LinSolver,
rootsfound: Vec<c_int>, // cache, with len() == number of eq
user_data: Pin<Box<UserData<F, G>>>,
}

// The user-data may be updated according to the type of `G`.
// However, we must ensure that `f: F` is always extracted in the same
// way because `cvrhs` may only be passed during initialization.
#[repr(C)]
struct UserData<F, G> {
f: F, // Right-hand side of the equation
g: G, // Function whose roots we want to compute (if any)
}

impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
where Ctx: Context,
V: Vector
{
/// Return a reference to the [`Context`] the CVode solver was
/// built with.
pub fn context(&self) -> &Ctx {
&self.ctx
}

/// Consumes CVode and return the [`Context`] it was built with.
pub fn into_context(self) -> Ctx {
self.ctx
}

fn new_with_fn<G1>(ctx: Ctx, cvode_mem: CVodeMem, t0: f64,
y0: V::NVectorRef<'a>,
rtol: f64, atol: f64, matrix: Matrix, linsolver: LinSolver,
f: F, g: G1, ng: usize
) -> CVode<'a, Ctx, V, F, G1> {
// FIXME: One can move `f` and `g` and set the user data
// before calling solving functions.
let user_data = Box::pin(UserData { f, g });
let user_data_ptr =
user_data.as_ref().get_ref() as *const _ as *mut c_void;
unsafe { CVodeSetUserData(cvode_mem.0, user_data_ptr) };
let mut rootsfound = Vec::with_capacity(ng);
rootsfound.resize(ng, 0);
CVode {
ctx, cvode_mem,
t0, y0,
rtol, atol, matrix, linsolver, rootsfound, user_data,
}
}
}

impl<'a, Ctx, V, F> CVode<'a, Ctx, V, F, ()>
where Ctx: Context,
V: Vector + Sized,
F: FnMut(f64, V::View<'_>, V::ViewMut<'_>)
{
/// Callback for the right-hand side of the equation.
extern "C" fn cvrhs(t: f64, nvy: N_Vector, nvdy: N_Vector,
user_data: *mut c_void
) -> c_int {
// Protect against unwinding in C code.
match std::panic::catch_unwind(|| {
// Get f from user_data, whatever the type of g.
let y = V::from_nvector(nvy);
let dy = V::from_nvector_mut(nvdy);
let u = unsafe { &mut *(user_data as *mut UserData<F, ()>) };
(u.f)(t, y, dy);
}) {
Ok(()) => 0,
Err(e) => {
eprintln!("sundials::CVode: right-hand side function \
panicked: {:?}", e);
std::process::abort() // Doesn't unwind
}
}
}

/// Initialize a CVode with method `llm`.
#[inline]
fn init(
name: &'static str, lmm: c_int,
// Take the context by move so only one solver can have it at
// a given time.
ctx: Ctx, t0: f64, y0: &'a V, f: F
// FIXME: Once y0 has been passed to CVodeInit, it is copied
// to interval structures and can be freed.
) -> Result<Self, Error> {
let cvode_mem = unsafe {
CVodeMem(CVodeCreate(lmm, ctx.as_ptr())) };
if cvode_mem.0.is_null() {
return Err(Error::Fail{name, msg: "Allocation failed"})
}
let n = y0.len();
// Safety: `y0` will not outlive `ctx`: they are both in CVode
// and `y0` cannot not be moved out of this structure.
let y0 = unsafe { y0.to_nvector(ctx.as_ptr()) };
let r = unsafe { CVodeInit(
cvode_mem.0,
Some(Self::cvrhs),
t0,
V::as_ptr(&y0) as *mut _) };
if r == CV_MEM_FAIL {
let msg = "a memory allocation request has failed";
return Err(Error::Fail{name, msg})
}
if r == CV_ILL_INPUT {
let msg = "An input argument has an illegal value";
return Err(Error::Fail{name, msg})
}
let rtol = 1e-6;
let atol = 1e-12;
// Set default tolerances (otherwise the solver will complain).
unsafe { CVodeSStolerances(
cvode_mem.0, rtol, atol); }
// Set the default linear solver (FIXME: configurable)
let mat = Matrix::new(name, &ctx, n, n)?;
let linsolver = unsafe { LinSolver::new(
name, ctx.as_ptr(), V::as_ptr(&y0) as *mut _, &mat)? };
let r = unsafe {
CVodeSetLinearSolver(cvode_mem.0, linsolver.0, mat.0) };
if r != CVLS_SUCCESS as i32 {
return Err(Error::Fail{name, msg: "could not attach linear solver"})
}
Ok(Self::new_with_fn(
ctx, cvode_mem, t0, y0,
rtol, atol, mat, linsolver,
f, (), 0))
}

/// Solver using the Adams linear multistep method. Recommended
/// for non-stiff problems.
// The fixed-point solver is recommended for nonstiff problems.
pub fn adams(ctx: Ctx, t0: f64, y0: &'a V, f: F) -> Result<Self, Error> {
Self::init("CVode::adams", CV_ADAMS, ctx, t0, y0, f)
}

/// Solver using the BDF linear multistep method. Recommended for
/// stiff problems.
// The default Newton iteration is recommended for stiff problems,
pub fn bdf(ctx: Ctx, t0: f64, y0: &'a V, f: F) -> Result<Self, Error> {
Self::init("CVode::bdf", CV_BDF, ctx, t0, y0, f)
}
}

impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
where V: Vector {
pub fn rtol(self, rtol: f64) -> Self {
unsafe { CVodeSStolerances(self.cvode_mem.0, rtol, self.atol); }
self
}

pub fn atol(self, atol: f64) -> Self {
unsafe { CVodeSStolerances(self.cvode_mem.0, self.rtol, atol); }
self
}

pub fn maxord(self, o: u8) -> Self {
unsafe { CVodeSetMaxOrd(self.cvode_mem.0, o as _); }
self
}

/// Specify the maximum number of steps to be taken by the solver
/// in its attempt to reach the next output time. Default: 500.
// FIXME: make sure "mxstep steps taken before reaching tout" does
// not abort the program.
pub fn mxsteps(self, n: usize) -> Self {
let n =
if n <= c_long::MAX as usize { n as _ } else { c_long::MAX };
unsafe { CVodeSetMaxNumSteps(self.cvode_mem.0, n) };
self
}

/// Specifies the maximum number of messages issued by the solver
/// warning that t + h = t on the next internal step.
pub fn max_hnil_warns(self, n: usize) -> Self {
unsafe { CVodeSetMaxHnilWarns(self.cvode_mem.0, n as _) };
self
}
}

/// # Root-finding capabilities
impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
where Ctx: Context,
V: Vector,
F: FnMut(f64, &V, &mut V) + Unpin,
G: Unpin
{
/// Callback for the root-finding callback for `N` functions,
/// where `N` is known at compile time.
extern "C" fn cvroot1<const N: usize, G1>(
t: f64, y: N_Vector, gout: *mut f64, user_data: *mut c_void) -> c_int
where G1: FnMut(f64, V::View<'_>, &mut [f64; N]) {
// Protect against unwinding in C code.
match std::panic::catch_unwind(|| {
let u = unsafe { &mut *(user_data as *mut UserData<F, G1>) };
let out = unsafe { &mut *(gout as *mut [f64; N]) };
(u.g)(t, V::from_nvector(y), out);
}) {
Ok(()) => 0,
Err(e) => {
eprintln!("sundials::CVode: function passed to .root() \
panicked: {:?}", e);
std::process::abort() // Doesn't unwind
}
}
}

/// Specifies that the roots of a set of functions gᵢ(t, y),
/// 0 ≤ `i` < `N` (given by `g`(t,y, [g₁,...,gₙ])) are to be
/// found while the IVP is being solved.
///
pub fn root<const M: usize, G1>(self, g: G1) -> CVode<'a, Ctx, V, F, G1>
where G1: FnMut(f64, V::View<'_>, &mut [f64; M]) {
// FIXME: Do we want a second (because it will not work when V
// = [f64;N] since the number of equations is usually not the
// same as the dimension of the problem) function accepting
// V::ViewMut and a dim as second parameter? In this case,
// one must wrap `g` to check for the dimension of the
// returned vector at each invocation.
let r = unsafe {
CVodeRootInit(self.cvode_mem.0, M as _,
Some(Self::cvroot1::<M, G1>)) };
if r == CV_MEM_FAIL {
panic!("Sundials::CVode::root: memory allocation failed.");
}
let u = *Pin::into_inner(self.user_data);
Self::new_with_fn(
self.ctx,
self.cvode_mem, self.t0, self.y0,
self.rtol, self.atol,
self.matrix, self.linsolver, u.f, g, M)
}
}


/// Return value of [`CVode::solve`] and [`CVode::step`].
#[derive(Debug, PartialEq)]
pub enum CV {
Ok,
Root(f64, Vec<bool>),
ErrFailure,
ConvFailure,
}

/// # Solving the IVP
impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
where Ctx: Context,
V: Vector {
// FIXME: for arrays, it would be more convenient to return the
// array. For types that are not on the stack (i.e. not Copy),
// taking it as an additional parameter is better.
pub fn solve(&mut self, t: f64, y: &mut V) -> CV {
Self::integrate(self, t, y, CV_NORMAL)
}

pub fn step(&mut self, t: f64, y: &mut V) -> CV {
Self::integrate(self, t, y, CV_ONE_STEP)
}

fn integrate(&mut self, t: f64, y: &mut V, itask: c_int) -> CV {
// Safety: `yout` does not escape this function and so will
// not outlive `self.ctx`.
let n = y.len();
let yout = unsafe { V::to_nvector_mut(y, self.ctx.as_ptr()) };
if t == self.t0 {
unsafe { ptr::copy_nonoverlapping(
N_VGetArrayPointer(V::as_ptr(&self.y0) as *mut _),
N_VGetArrayPointer(V::as_mut_ptr(&yout)),
n) };
return CV::Ok;
}
let mut t1 = self.t0;
let r = unsafe { CVode(
self.cvode_mem.0,
t,
V::as_mut_ptr(&yout),
&mut t1, itask) };
match r {
CV_SUCCESS => CV::Ok,
//CV_TSTOP_RETURN => ,
CV_ROOT_RETURN => {
let ret = unsafe { CVodeGetRootInfo(
self.cvode_mem.0,
self.rootsfound.as_mut_ptr()) };
debug_assert_eq!(ret, CV_SUCCESS);
let z: c_int = 0;
let roots = self.rootsfound.iter().map(|g| g != &z).collect();
CV::Root(t1, roots)
}
CV_MEM_NULL | CV_NO_MALLOC => unreachable!(),
CV_ILL_INPUT => panic!("CV_ILL_INPUT"),
CV_TOO_MUCH_WORK => panic!("Too much work"),
CV_TOO_MUCH_ACC => panic!("Could not satisfy desired accuracy"),
CV_ERR_FAILURE => CV::ErrFailure,
CV_CONV_FAILURE => CV::ConvFailure,
CV_LINIT_FAIL => panic!("CV_LINIT_FAIL"),
CV_LSETUP_FAIL => panic!("CV_LSETUP_FAIL"),
CV_LSOLVE_FAIL => panic!("CV_LSOLVE_FAIL"),
CV_RTFUNC_FAIL => panic!("The root function failed"),
_ => panic!("sundials::CVode: unexpected return code {}", r),
}
}

}

impl<'a, const N: usize, Ctx, F, G> CVode<'a, Ctx, [f64; N], F, G>
where Ctx: Context {
/// Return the solution at time `t`.
// FIXME: provide it for any type `V` — which must implement a
// creation function.
pub fn solution(&mut self, t: f64) -> ([f64; N], CV) {
let mut y = [f64::NAN; N];
let cv = self.solve(t, &mut y);
(y, cv)
}
}
7 changes: 7 additions & 0 deletions src/cvodes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//! `Cvodes` is a solver for stiff and nonstiff ODE systems with
//! sensitivity analysis capabilities (forward and adjoint).

#[allow(dead_code)]
/// Solver stiff and nonstiff ODE systems with sensitivity analysis
/// capabilities.
pub struct Cvodes {}
6 changes: 6 additions & 0 deletions src/ida.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! `Ida` is a solver for differential-algebraic systems
//! F(t, y, ẏ) = 0 based on BDF methods.

#[allow(dead_code)]
/// Implicit differential-algebraic solver.
pub struct IDA {}
Loading

0 comments on commit 2fb17eb

Please sign in to comment.