Skip to content

Commit

Permalink
Move pretrained weights to Releases
Browse files Browse the repository at this point in the history
  • Loading branch information
kazuto1011 committed Jun 29, 2020
1 parent 9b64d35 commit 4b71ed4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
</tr>
<tr>
<td rowspan="2"><strong>This repo</strong></td>
<td rowspan="2"><a href='https://drive.google.com/file/d/1Cgbl3Q_tHPFPyqfx2hx-9FZYBSbG5Rhy/view?usp=sharing'>Download</a></td>
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff10k-20000.pth">Download</a></td>
<td></td>
<td><strong>65.8</td>
<td><strong>45.7</strong></td>
Expand All @@ -57,7 +57,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
</td>
<td rowspan="2">164k <i>val</i></td>
<td rowspan="2"><strong>This repo</strong></td>
<td rowspan="2"><a href='https://drive.google.com/file/d/18kR928yl9Hz4xxuxnYgg7Hpi36hM8J2d/view?usp=sharing'>Download</a> &Dagger;</td>
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth">Download</a> &Dagger;</td>
<td></td>
<td>66.8</td>
<td>51.2</td>
Expand Down Expand Up @@ -112,7 +112,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
</tr>
<tr>
<td rowspan="2"><strong>This repo</strong></td>
<td rowspan="2"><a href='https://drive.google.com/file/d/1FaW2Sp7Jj3eaoyZtbabM1IWZnuScN-u6/view?usp=sharing'>Download</a></td>
<td rowspan="2"><a href="https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-vocaug-20000.pth">Download</a></td>
<td></td>
<td>94.64</td>
<td>86.50</td>
Expand Down Expand Up @@ -240,7 +240,7 @@ python demo.py single \

To run on a webcam:

```console
```bash
python demo.py live \
--config-path configs/voc12.yaml \
--model-path deeplabv2_resnet101_msc-vocaug-20000.pth
Expand All @@ -252,12 +252,11 @@ To run a CRF post-processing, add `--crf`. To run on a CPU, add `--cpu`.

### torch.hub

Model setup with 3 lines
Model setup with two lines

```python
import torch.hub
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182)
model.load_state_dict(torch.load("deeplabv2_resnet101_msc-cocostuff164k-100000.pth"))
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", pretrained='cocostuff164k', n_classes=182)
```

### Difference with Caffe version
Expand Down
49 changes: 27 additions & 22 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,41 @@

from __future__ import print_function

from torch.hub import load_state_dict_from_url

def deeplabv2_resnet101(pretrained=False, **kwargs):
"""
DeepLab v2 model with ResNet-101 backbone
n_classes (int): the number of classes
"""
model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/"
model_dict = {
"cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182),
"cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182),
"voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21),
}

if pretrained:
raise NotImplementedError(
"Please download from "
"https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance"
)

def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None):

from libs.models.deeplabv2 import DeepLabV2
from libs.models.msc import MSC

base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs)
model = MSC(base=base, scales=[0.5, 0.75])
# Model parameters
n_blocks = [3, 4, 23, 3]
atrous_rates = [6, 12, 18, 24]
if scales is None:
scales = [0.5, 0.75]

return model
base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates)
model = MSC(base=base, scales=scales)

# Load pretrained models
if isinstance(pretrained, str):

if __name__ == "__main__":
import torch.hub
assert pretrained in model_dict, list(model_dict.keys())
expected = model_dict[pretrained][1]
error_message = "Expected: n_classes={}".format(expected)
assert n_classes == expected, error_message

model = torch.hub.load(
"kazuto1011/deeplab-pytorch",
"deeplabv2_resnet101",
n_classes=182,
force_reload=True,
)
model_url = model_url_root + model_dict[pretrained][0]
state_dict = load_state_dict_from_url(model_url)
model.load_state_dict(state_dict)

return model

print(model)

0 comments on commit 4b71ed4

Please sign in to comment.