Skip to content

Commit

Permalink
add unittests for CCPi Regularisation and SIRF objects
Browse files Browse the repository at this point in the history
  • Loading branch information
paskino committed Jul 13, 2023
1 parent 2c084aa commit ec3bc6f
Showing 1 changed file with 109 additions and 2 deletions.
111 changes: 109 additions & 2 deletions Wrappers/Python/test/test_SIRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,30 @@
from cil.optimisation.functions import TotalVariation, L2NormSquared, KullbackLeibler
from cil.optimisation.algorithms import FISTA

import os
from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, TNV, FGP_dTV
from cil.utilities.display import show2D

from testclass import CCPiTestClass
from utils import has_nvidia, has_ccpi_regularisation, initialise_tests

initialise_tests()

try:
import sirf.STIR as pet
import sirf.Gadgetron as mr
from sirf.Utilities import examples_data_path

has_sirf = True
except ImportError as ie:
has_sirf = False

if has_ccpi_regularisation:
from ccpi.filters import regularisers
from cil.plugins.ccpi_regularisation.functions import FGP_TV, TGV, FGP_dTV, TNV



class KullbackLeiblerSIRF(object):

def setUp(self):
Expand Down Expand Up @@ -364,6 +376,101 @@ def test_BlockDataContainer_with_SIRF_DataContainer_subtract(self):



class CCPiRegularisationWithSIRFTests(CCPiTestClass):

def setUpFGP_TV(self, max_iteration=100, alpha=1.):
return alpha*FGP_TV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_FGP_TV_call_works(self):
regulariser = self.setUpFGP_TV()
output_number = regulariser(self.image1)
self.assertTrue(True)
# TODO: test the actual value
# expected = 160600016.0
# np.testing.assert_allclose(output_number, expected, rtol=1e-5)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_FGP_TV_proximal_works(self):
regulariser = self.setUpFGP_TV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# TGV
def setUpTGV(self, max_iteration=100, alpha=1.):
return alpha * TGV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_call_works(self):
regulariser = self.setUpTGV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_proximal_works(self):
regulariser = self.setUpTGV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# dTV
def setUpdTV(self, max_iteration=100, alpha=1.):
return alpha * FGP_dTV(reference=self.image2, max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_call_works(self):
regulariser = self.setUpTGV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_proximal_works(self):
regulariser = self.setUpTGV()
solution = regulariser.proximal(x=self.image1, tau=1)
self.assertTrue(True)

# TNV
def setUpTNV(self, max_iteration=100, alpha=1.):
return alpha * TNV(max_iteration=max_iteration)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TNV_call_works(self):
regulariser = self.setUpTNV()
output_number = regulariser(self.image1)
self.assertTrue(True)

@unittest.skipUnless(has_sirf and has_ccpi_regularisation, "Has SIRF and CCPi Regularisation")
def test_TGV_proximal_works(self):
regulariser = self.setUpTNV()
solution = regulariser.proximal(x=self.image1, tau=1.)
self.assertTrue(True)

class TestPETRegularisation(CCPiRegularisationWithSIRFTests):
skip_TNV_on_2D = True
def setUp(self):
self.image1 = pet.ImageData(os.path.join(
examples_data_path('PET'),'thorax_single_slice','emission.hv'
))
self.image2 = self.image1 * 0.5

@unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D")
def test_TNV_call_works(self):
super().test_TNV_call_works()

@unittest.skipIf(skip_TNV_on_2D, "TNV not implemented for 2D")
def test_TNV_proximal_works(self):
super().test_TNV_proximal_works()

class TestRegRegularisation(CCPiRegularisationWithSIRFTests):
def setUp(self):
self.image1 = = reg.ImageData(os.path.join(examples_data_path('Registration'),'test2.nii.gz'))
self.image2 = self.image1 * 0.5



class TestMRRegularisation(CCPiRegularisationWithSIRFTests):
def setUp(self):
acq_data = mr.AcquisitionData(os.path.join(examples_data_path('MR'),'simulated_MR_2D_cartesian.h5'))
preprocessed_data = mr.preprocess_acquisition_data(acq_data)
recon = mr.FullySampledReconstructor()
recon.set_input(preprocessed_data)
recon.process()
self.image1 = recon.get_output()
self.image2 = self.image1 * 0.5

0 comments on commit ec3bc6f

Please sign in to comment.