From ec3bc6f74377b7b7f28b2158e8d836f97f5b0571 Mon Sep 17 00:00:00 2001 From: Edoardo Pasca Date: Thu, 13 Jul 2023 09:27:27 +0100 Subject: [PATCH] add unittests for CCPi Regularisation and SIRF objects --- Wrappers/Python/test/test_SIRF.py | 111 +++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 2 deletions(-) diff --git a/Wrappers/Python/test/test_SIRF.py b/Wrappers/Python/test/test_SIRF.py index eaece897e..da1795fcc 100644 --- a/Wrappers/Python/test/test_SIRF.py +++ b/Wrappers/Python/test/test_SIRF.py @@ -28,7 +28,12 @@ from cil.optimisation.functions import TotalVariation, L2NormSquared, KullbackLeibler from cil.optimisation.algorithms import FISTA +import os +from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, TNV, FGP_dTV +from cil.utilities.display import show2D + from testclass import CCPiTestClass +from utils import has_nvidia, has_ccpi_regularisation, initialise_tests initialise_tests() @@ -36,10 +41,17 @@ import sirf.STIR as pet import sirf.Gadgetron as mr from sirf.Utilities import examples_data_path + has_sirf = True except ImportError as ie: has_sirf = False +if has_ccpi_regularisation: + from ccpi.filters import regularisers + from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, FGP_dTV, TNV + + + class KullbackLeiblerSIRF(object): def setUp(self): @@ -364,6 +376,101 @@ def test_BlockDataContainer_with_SIRF_DataContainer_subtract(self): +class CCPiRegularisationWithSIRFTests(CCPiTestClass): + + def setUpFGP_TV(self, max_iteration=100, alpha=1.): + return alpha*FGP_TV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_FGP_TV_call_works(self): + regulariser = self.setUpFGP_TV() + output_number = regulariser(self.image1) + self.assertTrue(True) + # TODO: test the actual value + # expected = 160600016.0 + # np.testing.assert_allclose(output_number, expected, rtol=1e-5) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_FGP_TV_proximal_works(self): + regulariser = self.setUpFGP_TV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # TGV + def setUpTGV(self, max_iteration=100, alpha=1.): + return alpha * TGV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_call_works(self): + regulariser = self.setUpTGV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_proximal_works(self): + regulariser = self.setUpTGV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # dTV + def setUpdTV(self, max_iteration=100, alpha=1.): + return alpha * FGP_dTV(reference=self.image2, max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_call_works(self): + regulariser = self.setUpTGV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_proximal_works(self): + regulariser = self.setUpTGV() + solution = regulariser.proximal(x=self.image1, tau=1) + self.assertTrue(True) + + # TNV + def setUpTNV(self, max_iteration=100, alpha=1.): + return alpha * TNV(max_iteration=max_iteration) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TNV_call_works(self): + regulariser = self.setUpTNV() + output_number = regulariser(self.image1) + self.assertTrue(True) + + @unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation") + def test_TGV_proximal_works(self): + regulariser = self.setUpTNV() + solution = regulariser.proximal(x=self.image1, tau=1.) + self.assertTrue(True) + +class TestPETRegularisation(CCPiRegularisationWithSIRFTests): + skip_TNV_on_2D = True + def setUp(self): + self.image1 = pet.ImageData(os.path.join( + examples_data_path('PET'),'thorax_single_slice','emission.hv' + )) + self.image2 = self.image1 * 0.5 + + @unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D") + def test_TNV_call_works(self): + super().test_TNV_call_works() + + @unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D") + def test_TNV_proximal_works(self): + super().test_TNV_proximal_works() + +class TestRegRegularisation(CCPiRegularisationWithSIRFTests): + def setUp(self): + self.image1 = = reg.ImageData(os.path.join(examples_data_path('Registration'),'test2.nii.gz')) + self.image2 = self.image1 * 0.5 - - +class TestMRRegularisation(CCPiRegularisationWithSIRFTests): + def setUp(self): + acq_data = mr.AcquisitionData(os.path.join(examples_data_path('MR'),'simulated_MR_2D_cartesian.h5')) + preprocessed_data = mr.preprocess_acquisition_data(acq_data) + recon = mr.FullySampledReconstructor() + recon.set_input(preprocessed_data) + recon.process() + self.image1 = recon.get_output() + self.image2 = self.image1 * 0.5 \ No newline at end of file