From f740a4dc51f542b84244e11316d998274eb6bb6c Mon Sep 17 00:00:00 2001 From: Magnar Bjorgve Date: Wed, 27 Mar 2024 11:33:20 +0100 Subject: [PATCH 1/2] Refactor derivative operator creation with factory function This commit introduces a new factory function `Derivative` to the `derivatives` module. This function streamlines the creation of derivative operators by allowing users to specify the type of derivative operator they wish to create using a string identifier. The supported types are "center", "simple", "forward", "backward", "b-spline", and "ph", which correspond to the ABGVOperator with specific parameters and the BSOperator and PHOperator with a specified order. The factory function returns a unique pointer to the created derivative operator, ensuring proper memory management and simplifying the Python interface. Additionally, this commit removes the direct exposure of the specific derivative operator classes to the Python interface, encouraging the use of the new factory function for creating derivative operators. --- src/vampyr/operators/derivatives.h | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/vampyr/operators/derivatives.h b/src/vampyr/operators/derivatives.h index cd34e47f..c2e3a4ec 100644 --- a/src/vampyr/operators/derivatives.h +++ b/src/vampyr/operators/derivatives.h @@ -52,5 +52,26 @@ template void derivatives(pybind11::module &m) { )mydelimiter") // clang-format on .def(py::init &, int>(), "mra"_a, "order"_a = 1); + + + // Factory function to create derivative operators based on type + m.def("Derivative", [](const MultiResolutionAnalysis &mra, const std::string &type, int order = 1) -> std::unique_ptr> { + if (type == "center") { + return std::make_unique>(mra, 0.5, 0.5); + } else if (type == "simple") { + return std::make_unique>(mra, 0.0, 0.0); + } else if (type == "forward") { + return std::make_unique>(mra, 0.0, 1.0); + } else if (type == "backward") { + return std::make_unique>(mra, 1.0, 0.0); + } else if (type == "b-spline") { + return std::make_unique>(mra, order); + } else if (type == "ph") { + return std::make_unique>(mra, order); + } else { + throw std::invalid_argument("Unknown derivative type: " + type); + } + }, "mra"_a, "type"_a, "order"_a = 1); + } } // namespace vampyr From 2a87fae094987a2a6597d88eeea214ee15fd4c7e Mon Sep 17 00:00:00 2001 From: Stig Rune Jensen Date: Wed, 27 Mar 2024 17:17:52 +0100 Subject: [PATCH 2/2] Default to "center" derivative --- src/vampyr/operators/derivatives.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vampyr/operators/derivatives.h b/src/vampyr/operators/derivatives.h index c2e3a4ec..7592027e 100644 --- a/src/vampyr/operators/derivatives.h +++ b/src/vampyr/operators/derivatives.h @@ -55,7 +55,7 @@ template void derivatives(pybind11::module &m) { // Factory function to create derivative operators based on type - m.def("Derivative", [](const MultiResolutionAnalysis &mra, const std::string &type, int order = 1) -> std::unique_ptr> { + m.def("Derivative", [](const MultiResolutionAnalysis &mra, const std::string &type = "center", int order = 1) -> std::unique_ptr> { if (type == "center") { return std::make_unique>(mra, 0.5, 0.5); } else if (type == "simple") {