Let CVode no longer depend on y0 after initialization
Chris00 committed Dec 30, 2023
1 parent 66b1033 commit 6337761
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions src/cvode.rs
@@ -16,7 +16,7 @@
use std::{
ffi::{c_int, c_void, c_long},
pin::Pin,
ptr,
marker::PhantomData,
};
use sundials_sys::*;
use super::{
@@ -42,18 +42,18 @@ impl Drop for CVodeMem {
/// context, `V` is the type of vectors, `F` the type of the function
/// being used as right-hand side of the ODE, and `G` is the type of
/// the functions (if any) of which we want to seek roots.
pub struct CVode<'a, Ctx, V, F, G>
pub struct CVode<Ctx, V, F, G>
where V: Vector {
// One must take ownership of the context because it can only be
// used in a single ODE solver.
ctx: Ctx,
cvode_mem: CVodeMem,
t0: f64,
// FIXME: after used in CVodeInit, the initial vector is copied to
// the internal structures. No need to keep it.
y0: V::NVectorRef<'a>,
vec: PhantomData<V>,
rtol: f64,
atol: f64,
// We hold `Matrix` and `LinSolver` so they are freed when `CVode`
// is dropped.
matrix: Matrix,
linsolver: LinSolver,
rootsfound: Vec<c_int>, // cache, with len() == number of eq
@@ -69,7 +69,7 @@ struct UserData<F, G> {
g: G, // Function whose roots we want to compute (if any)
}

impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where Ctx: Context,
V: Vector
{
@@ -85,10 +85,9 @@ where Ctx: Context,
}

fn new_with_fn<G1>(ctx: Ctx, cvode_mem: CVodeMem, t0: f64,
y0: V::NVectorRef<'a>,
rtol: f64, atol: f64, matrix: Matrix, linsolver: LinSolver,
f: F, g: G1, ng: usize
) -> CVode<'a, Ctx, V, F, G1> {
) -> CVode<Ctx, V, F, G1> {
// FIXME: One can move `f` and `g` and set the user data
// before calling solving functions.
let user_data = Box::pin(UserData { f, g });
@@ -99,13 +98,13 @@ where Ctx: Context,
rootsfound.resize(ng, 0);
CVode {
ctx, cvode_mem,
t0, y0,
t0, vec: PhantomData,
rtol, atol, matrix, linsolver, rootsfound, user_data,
}
}
}

impl<'a, Ctx, V, F> CVode<'a, Ctx, V, F, ()>
impl<Ctx, V, F> CVode<Ctx, V, F, ()>
where Ctx: Context,
V: Vector + Sized,
F: FnMut(f64, V::View<'_>, V::ViewMut<'_>)
@@ -132,23 +131,20 @@ where Ctx: Context,
}

/// Initialize a CVode with method `lmm`.
#[inline]
fn init(
name: &'static str, lmm: c_int,
// Take the context by move so only one solver can have it at
// a given time.
ctx: Ctx, t0: f64, y0: &'a V, f: F
// FIXME: Once y0 has been passed to CVodeInit, it is copied
// to internal structures and can be freed.
ctx: Ctx, t0: f64, y0: &V, f: F
) -> Result<Self, Error> {
let cvode_mem = unsafe {
CVodeMem(CVodeCreate(lmm, ctx.as_ptr())) };
if cvode_mem.0.is_null() {
return Err(Error::Fail{name, msg: "Allocation failed"})
}
let n = y0.len();
// Safety: `y0` will not outlive `ctx`: they are both in CVode
// and `y0` cannot be moved out of this structure.
// SAFETY: Once `y0` has been passed to `CVodeInit`, it is
// copied to internal structures and thus can be freed.
let y0 = unsafe { y0.to_nvector(ctx.as_ptr()) };
let r = unsafe { CVodeInit(
cvode_mem.0,
@@ -178,27 +174,27 @@ where Ctx: Context,
return Err(Error::Fail{name, msg: "could not attach linear solver"})
}
Ok(Self::new_with_fn(
ctx, cvode_mem, t0, y0,
ctx, cvode_mem, t0,
rtol, atol, mat, linsolver,
f, (), 0))
}

/// Solver using the Adams linear multistep method. Recommended
/// for non-stiff problems.
// The fixed-point solver is recommended for nonstiff problems.
pub fn adams(ctx: Ctx, t0: f64, y0: &'a V, f: F) -> Result<Self, Error> {
pub fn adams(ctx: Ctx, t0: f64, y0: &V, f: F) -> Result<Self, Error> {
Self::init("CVode::adams", CV_ADAMS, ctx, t0, y0, f)
}

/// Solver using the BDF linear multistep method. Recommended for
/// stiff problems.
// The default Newton iteration is recommended for stiff problems.
pub fn bdf(ctx: Ctx, t0: f64, y0: &'a V, f: F) -> Result<Self, Error> {
pub fn bdf(ctx: Ctx, t0: f64, y0: &V, f: F) -> Result<Self, Error> {
Self::init("CVode::bdf", CV_BDF, ctx, t0, y0, f)
}
}

impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where V: Vector {
pub fn rtol(self, rtol: f64) -> Self {
unsafe { CVodeSStolerances(self.cvode_mem.0, rtol, self.atol); }
@@ -235,7 +231,7 @@ where V: Vector {
}

/// # Root-finding capabilities
impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where Ctx: Context,
V: Vector,
F: FnMut(f64, &V, &mut V) + Unpin,
@@ -265,7 +261,7 @@ where Ctx: Context,
/// 0 ≤ `i` < `N` (given by `g`(t,y, [g₁,...,gₙ])) are to be
/// found while the IVP is being solved.
///
pub fn root<const M: usize, G1>(self, g: G1) -> CVode<'a, Ctx, V, F, G1>
pub fn root<const M: usize, G1>(self, g: G1) -> CVode<Ctx, V, F, G1>
where G1: FnMut(f64, V::View<'_>, &mut [f64; M]) {
// FIXME: Do we want a second (because it will not work when V
// = [f64;N] since the number of equations is usually not the
@@ -282,7 +278,7 @@ where Ctx: Context,
let u = *Pin::into_inner(self.user_data);
Self::new_with_fn(
self.ctx,
self.cvode_mem, self.t0, self.y0,
self.cvode_mem, self.t0,
self.rtol, self.atol,
self.matrix, self.linsolver, u.f, g, M)
}
@@ -299,7 +295,7 @@ pub enum CV {
}

/// # Solving the IVP
impl<'a, Ctx, V, F, G> CVode<'a, Ctx, V, F, G>
impl<Ctx, V, F, G> CVode<Ctx, V, F, G>
where Ctx: Context,
V: Vector {
// FIXME: for arrays, it would be more convenient to return the
@@ -316,15 +312,16 @@ where Ctx: Context,
fn integrate(&mut self, t: f64, y: &mut V, itask: c_int) -> CV {
// Safety: `yout` does not escape this function and so will
// not outlive `self.ctx`.
let n = y.len();
//let n = y.len();
let yout = unsafe { V::to_nvector_mut(y, self.ctx.as_ptr()) };
if t == self.t0 {
unsafe { ptr::copy_nonoverlapping(
N_VGetArrayPointer(V::as_ptr(&self.y0) as *mut _),
N_VGetArrayPointer(V::as_mut_ptr(&yout)),
n) };
return CV::Ok;
}
// FIXME
// if t == self.t0 {
// unsafe { ptr::copy_nonoverlapping(
// N_VGetArrayPointer(V::as_ptr(&self.y0) as *mut _),
// N_VGetArrayPointer(V::as_mut_ptr(&yout)),
// n) };
// return CV::Ok;
// }
let mut t1 = self.t0;
let r = unsafe { CVode(
self.cvode_mem.0,
@@ -359,7 +356,7 @@

}

impl<'a, const N: usize, Ctx, F, G> CVode<'a, Ctx, [f64; N], F, G>
impl<const N: usize, Ctx, F, G> CVode<Ctx, [f64; N], F, G>
where Ctx: Context {
/// Return the solution at time `t`.
// FIXME: provide it for any type `V` — which must implement a
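
With this change the initial vector is only read during `CVodeInit` (which
copies it into CVODE's internal structures), so `CVode` now stores a
`PhantomData<V>` marker instead of an `N_Vector` reference and the `'a`
lifetime disappears from the public type. A minimal sketch of what that
allows, assuming only the constructors visible in this diff (imports are
elided; how a `Context` value is obtained and the exact closure argument
types for `V = [f64; 2]` are assumptions, not part of this commit):

    // Sketch only: `Context`, `CVode`, and `Error` are the items from this
    // file; array indexing on the closure views is an assumption.
    fn demo(ctx: impl Context) -> Result<(), Error> {
        let mut y0 = [1.0, 0.0];
        // Harmonic oscillator: y0' = y1, y1' = -y0.
        let ode = CVode::adams(ctx, 0.0, &y0, |_t, y, dy| {
            dy[0] = y[1];
            dy[1] = -y[0];
        })?
        .rtol(1e-8);
        // Before this commit the solver held a borrow of `y0` for its whole
        // lifetime; now the borrow ends at the `adams` call, so `y0` may be
        // mutated (or dropped) while `ode` lives on.
        y0[0] = 2.0;
        let _ = (ode, y0);
        Ok(())
    }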
