-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.py
80 lines (74 loc) · 2.57 KB
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from math import sqrt
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from utils import *
import os
from loss import *
from model import *
from skimage.feature.tests.test_orb import img
# Process-wide environment tweak applied at import time, after torch and
# matplotlib have already been imported above.
# NOTE(review): KMP_DUPLICATE_LIB_OK is the Intel OpenMP switch that tolerates
# more than one copy of the OpenMP runtime being loaded; presumably this is a
# workaround for the "libiomp5 already initialized" abort on some setups —
# confirm it is still needed, since it can mask a broken environment.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')
class Net(nn.Module):
    """Bundle a named backbone network with the loss used to train it.

    The backbone is selected by ``model_name`` from ``_MODELS``. It is moved
    to CUDA with ``.cuda()`` and wrapped in ``torch.nn.DataParallel``. The
    loss is ``SoftIoULoss`` for every backbone except ``'ISNet'``, which uses
    ``ISNetLoss``.

    Args:
        model_name: Key of ``_MODELS`` selecting the backbone.
        mode: ``'train'`` builds mode-aware backbones with ``mode='train'``;
            any other value builds them with ``mode='test'``. Backbones whose
            constructor takes no ``mode`` ignore it.

    Raises:
        ValueError: If ``model_name`` is not a key of ``_MODELS``.
    """

    # model_name -> factory taking the normalized mode ('train' or 'test').
    # Lambdas defer name lookup, so the star-imported backbone classes are
    # resolved only when that particular backbone is requested.
    _MODELS = {
        'DNANet': lambda mode: DNANet(mode=mode),
        'DCANet': lambda mode: DCANet(mode=mode),
        'DCAS6Net': lambda mode: DCAS6Net(mode=mode),
        'DCAS7Net': lambda mode: DCAS7Net(mode=mode),
        # NOTE(review): the key is 'DNANet_BY' but the constructor is spelled
        # DNAnet_BY (kept exactly as before) — confirm it matches the class
        # actually exported by the star-imported modules.
        'DNANet_BY': lambda mode: DNAnet_BY(mode=mode),
        'ACM': lambda mode: ACM(),
        'ALCNet': lambda mode: ALCNet(),
        'ISNet': lambda mode: ISNet(mode=mode),
        'RISTDnet': lambda mode: RISTDnet(),
        'UIUNet': lambda mode: UIUNet(mode=mode),
        'U-Net': lambda mode: Unet(),
        'ISTDU-Net': lambda mode: ISTDU_Net(),
        'RDIAN': lambda mode: RDIAN(),
    }
    # Backbones for which the requested mode is echoed to stdout.
    _ANNOUNCE_MODE = frozenset({'DCANet', 'DCAS6Net', 'DCAS7Net'})

    def __init__(self, model_name, mode):
        super().__init__()
        self.model_name = model_name
        try:
            build = self._MODELS[model_name]
        except KeyError:
            # Fail early with an actionable message instead of the opaque
            # AttributeError that an unset ``self.model`` would cause later.
            known = ', '.join(self._MODELS)
            raise ValueError(
                f'Unknown model_name {model_name!r}; expected one of: {known}'
            ) from None
        self.cal_loss = SoftIoULoss()
        # Mode-aware backbones only distinguish 'train' from everything else.
        self.model = build('train' if mode == 'train' else 'test')
        if model_name in self._ANNOUNCE_MODE:
            print('mode_is=', mode)
        if model_name == 'ISNet':
            # ISNet is trained with its own loss instead of SoftIoULoss.
            self.cal_loss = ISNetLoss()
        self.model = self.model.cuda()
        self.model = torch.nn.DataParallel(self.model)

    def forward(self, img):
        """Run the wrapped (DataParallel) backbone on ``img`` and return its output."""
        return self.model(img)

    def loss(self, pred, gt_mask):
        """Return the configured training loss for ``pred`` against ``gt_mask``."""
        return self.cal_loss(pred, gt_mask)