Skip to content

Commit

Permalink
Avoid killed kernel with badly defined analytic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjorgve committed Jan 28, 2023
1 parent 01adce9 commit 3012534
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/vampyr/tests/test_projector1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from vampyr import vampyr1d as vp

def test_ScalingProjector():
def f(x):
return x

mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7)
P_scaling = vp.ScalingProjector(mra, 2)
P_wavelet = vp.WaveletProjector(mra, 2)

with pytest.raises(Exception):
P_scaling(f)

with pytest.raises(Exception):
P_wavelet(f)
17 changes: 17 additions & 0 deletions src/vampyr/tests/test_projector3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from vampyr import vampyr3d as vp

def test_ScalingProjector():
def f(x):
return x

mra = vp.MultiResolutionAnalysis(box=[0, 1], order=7)
P_scaling = vp.ScalingProjector(mra, 2)
P_wavelet = vp.WaveletProjector(mra, 2)

with pytest.raises(Exception):
P_scaling(f)

with pytest.raises(Exception):
P_wavelet(f)
22 changes: 21 additions & 1 deletion src/vampyr/treebuilders/project.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <pybind11/functional.h>
#include <typeinfo>

#include <pybind11/functional.h>
#include "PyProjectors.h"

namespace vampyr {
Expand All @@ -18,6 +19,15 @@ template <int D> void project(pybind11::module &m) {
.def(
"__call__",
[](PyScalingProjector<D> &P, std::function<double(const Coord<D> &r)> func) {

try {
auto arr = std::array<double, D>();
arr.fill(111111.111); // A number which hopefully does not divide by zero
func(arr);
} catch (py::cast_error &e) {
py::print("Error: Invalid definition of analytic function");
throw;
}
auto old_threads = mrcpp_get_num_threads();
set_max_threads(1);
auto out = P(func);
Expand All @@ -33,6 +43,16 @@ template <int D> void project(pybind11::module &m) {
.def(
"__call__",
[](PyWaveletProjector<D> &P, std::function<double(const Coord<D> &r)> func) {

try {
auto arr = std::array<double, D>();
arr.fill(111111.111); // A number which hopefully does not divide by zero
func(arr);
} catch (py::cast_error &e) {
py::print("Error: Invalid definition of analytic function");
throw;
}

auto old_threads = mrcpp_get_num_threads();
set_max_threads(1);
auto out = P(func);
Expand Down

0 comments on commit 3012534

Please sign in to comment.