Commit
[CodeCamp2023-504] Add a new script to support the WBF
Co-authored-by: huanghaian <[email protected]>
Morty-Xu and hhaAndroid committed Sep 12, 2023
1 parent 59b0fc5 commit 769c810
Showing 6 changed files with 754 additions and 1 deletion.
212 changes: 212 additions & 0 deletions demo/demo_multi_model.py
@@ -0,0 +1,212 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Support for multi-model fusion, and currently only the Weighted Box Fusion
(WBF) fusion method is supported.
References: https://github.com/ZFTurbo/Weighted-Boxes-Fusion
Example:
python demo/demo_multi_model.py demo/demo.jpg \
./configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_1x_coco.py \
./configs/retinanet/retinanet_r50-caffe_fpn_1x_coco.py \
--checkpoints \
https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth \ # noqa
https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth \
--weights 1 2
"""

import argparse
import os.path as osp

import mmcv
import mmengine
from mmengine.fileio import isdir, join_path, list_dir_or_file
from mmengine.logging import print_log
from mmengine.structures import InstanceData

from mmdet.apis import DetInferencer
from mmdet.models.utils import weighted_boxes_fusion
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDetection multi-model inference demo')
    parser.add_argument(
        'inputs', type=str, help='Input image file or folder path.')
    parser.add_argument(
        'config',
        type=str,
        nargs='*',
        help='Config file(s); multiple files are supported.')
    parser.add_argument(
        '--checkpoints',
        type=str,
        nargs='*',
        help='Checkpoint file(s); multiple files are supported and must '
        'correspond one-to-one with the configs above.',
    )
    parser.add_argument(
        '--weights',
        type=float,
        nargs='*',
        default=None,
        help='Weight of each model; must correspond one-to-one with '
        'the configs above.')
    parser.add_argument(
        '--fusion-iou-thr',
        type=float,
        default=0.55,
        help='IoU threshold for two boxes to be considered a match in WBF.')
    parser.add_argument(
        '--skip-box-thr',
        type=float,
        default=0.0,
        help='Exclude boxes with a score lower than this threshold in WBF.')
    parser.add_argument(
        '--conf-type',
        type=str,
        default='avg',  # avg, max, box_and_model_avg, absent_model_aware_avg
        help='How to calculate the confidence of the fused boxes in WBF.')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='outputs',
        help='Output directory of images or prediction results.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--pred-score-thr',
        type=float,
        default=0.3,
        help='Bbox score threshold.')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='Inference batch size.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--no-save-vis',
        action='store_true',
        help='Do not save detection visualization results.')
    parser.add_argument(
        '--no-save-pred',
        action='store_true',
        help='Do not save detection JSON results.')
    parser.add_argument(
        '--palette',
        default='none',
        choices=['coco', 'voc', 'citys', 'random', 'none'],
        help='Color palette used for visualization')

    args = parser.parse_args()

    if args.no_save_vis and args.no_save_pred:
        args.out_dir = ''

    return args


def main():
    args = parse_args()

    results = []
    cfg_visualizer = None
    dataset_meta = None

    # Load all inputs up front so that every model runs on the same images.
    inputs = []
    filename_list = []
    if isdir(args.inputs):
        files = list_dir_or_file(
            args.inputs, list_dir=False, suffix=IMG_EXTENSIONS)
        for filename in files:
            img = mmcv.imread(join_path(args.inputs, filename))
            inputs.append(img)
            filename_list.append(filename)
    else:
        img = mmcv.imread(args.inputs)
        inputs.append(img)
        img_name = osp.basename(args.inputs)
        filename_list.append(img_name)

    # Run each detector once over all inputs and collect its raw predictions.
    for i, (config,
            checkpoint) in enumerate(zip(args.config, args.checkpoints)):
        inferencer = DetInferencer(
            config, checkpoint, device=args.device, palette=args.palette)

        result_raw = inferencer(
            inputs=inputs,
            batch_size=args.batch_size,
            no_save_vis=True,
            pred_score_thr=args.pred_score_thr)

        if i == 0:
            cfg_visualizer = inferencer.cfg.visualizer
            dataset_meta = inferencer.model.dataset_meta
            results = [{
                'bboxes_list': [],
                'scores_list': [],
                'labels_list': []
            } for _ in range(len(result_raw['predictions']))]

        for res, raw in zip(results, result_raw['predictions']):
            res['bboxes_list'].append(raw['bboxes'])
            res['scores_list'].append(raw['scores'])
            res['labels_list'].append(raw['labels'])

    visualizer = VISUALIZERS.build(cfg_visualizer)
    visualizer.dataset_meta = dataset_meta

    # Fuse the per-model predictions of each image with WBF, then save and
    # visualize the fused result.
    for i in range(len(results)):
        bboxes, scores, labels = weighted_boxes_fusion(
            results[i]['bboxes_list'],
            results[i]['scores_list'],
            results[i]['labels_list'],
            weights=args.weights,
            iou_thr=args.fusion_iou_thr,
            skip_box_thr=args.skip_box_thr,
            conf_type=args.conf_type)

        pred_instances = InstanceData()
        pred_instances.bboxes = bboxes
        pred_instances.scores = scores
        pred_instances.labels = labels

        fusion_result = DetDataSample(pred_instances=pred_instances)

        img_name = filename_list[i]

        if not args.no_save_pred:
            out_json_path = (
                args.out_dir + '/preds/' + img_name.split('.')[0] + '.json')
            mmengine.dump(
                {
                    'labels': labels.tolist(),
                    'scores': scores.tolist(),
                    'bboxes': bboxes.tolist()
                }, out_json_path)

        out_file = osp.join(args.out_dir, 'vis',
                            img_name) if not args.no_save_vis else None

        visualizer.add_datasample(
            img_name,
            inputs[i][..., ::-1],
            data_sample=fusion_result,
            show=args.show,
            draw_gt=False,
            wait_time=0,
            pred_score_thr=args.pred_score_thr,
            out_file=out_file)

    if not args.no_save_vis:
        print_log(f'results have been saved at {args.out_dir}')


if __name__ == '__main__':
    main()
74 changes: 74 additions & 0 deletions docs/en/user_guides/useful_tools.md
@@ -111,6 +111,80 @@ python tools/analysis_tools/analyze_results.py \
--show-score-thr 0.3
```

## Fusing results from multiple models

`tools/analysis_tools/fuse_results.py` can fuse predictions from different object detection models using Weighted Boxes Fusion (WBF). (Currently only the COCO format is supported.)

**Usage**

```shell
python tools/analysis_tools/fuse_results.py \
${PRED_RESULTS} \
[--annotation ${ANNOTATION}] \
[--weights ${WEIGHTS}] \
[--fusion-iou-thr ${FUSION_IOU_THR}] \
[--skip-box-thr ${SKIP_BOX_THR}] \
[--conf-type ${CONF_TYPE}] \
[--eval-single ${EVAL_SINGLE}] \
[--save-fusion-results ${SAVE_FUSION_RESULTS}] \
[--out-dir ${OUT_DIR}]
```

Description of all arguments:

- `pred-results`: Paths of the detection result files from different models. (Currently only the COCO format is supported.)
- `--annotation`: Path of the ground-truth annotation file.
- `--weights`: List of weights for each model. Default: `None`, which means a weight of 1 for each model.
- `--fusion-iou-thr`: IoU threshold for two boxes to be considered a match. Default: `0.55`.
- `--skip-box-thr`: Score threshold used in the WBF algorithm; boxes whose confidence is lower than this value are excluded. Default: `0`.
- `--conf-type`: How to calculate the confidence of the fused boxes (see the short sketch after this list).
  - `avg`: average value (default).
  - `max`: maximum value.
  - `box_and_model_avg`: box and model wise hybrid weighted average.
  - `absent_model_aware_avg`: weighted average that takes absent models into account.
- `--eval-single`: Whether to evaluate each single model. Default: `False`.
- `--save-fusion-results`: Whether to save the fusion results. Default: `False`.
- `--out-dir`: Output directory for the fusion results.
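
The `--conf-type` options only differ in how the score of a fused box is computed from the scores of the boxes merged into it. As a rough, self-contained illustration (not the implementation inside the tool), the default `avg` rule amounts to a weighted mean of the matched scores, using the per-model weights passed via `--weights`:

```python
# Toy illustration of the default `avg` confidence rule (not the library code).
# Three models predicted overlapping boxes for the same object; their scores
# are averaged with the per-model weights passed via --weights.
scores = [0.90, 0.80, 0.60]   # scores of the matched boxes, one per model
weights = [1.0, 2.0, 3.0]     # e.g. --weights 1 2 3

fused_score = sum(w * s for w, s in zip(weights, scores)) / sum(weights)
print(round(fused_score, 3))  # 0.717 -> confidence assigned to the fused box
```

`max` instead keeps the highest matched score, while the two `*_avg` variants additionally take into account which models contributed a box to the cluster.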

**Examples**:
Assume that you have obtained 3 result files from the corresponding models through `tools/test.py`, located at './faster-rcnn_r50-caffe_fpn_1x_coco.json', './retinanet_r50-caffe_fpn_1x_coco.json' and './cascade-rcnn_r50-caffe_fpn_1x_coco.json', and that the ground-truth file is './annotation.json'.

1. Fuse the predictions of the three models and evaluate the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3
```

2. Evaluate each single model as well as the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3 \
--eval-single
```

3. Fuse the predictions of the three models and save the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3 \
--save-fusion-results \
--out-dir outputs/fusion
```

## Visualization

### Visualize Datasets
74 changes: 74 additions & 0 deletions docs/zh_cn/user_guides/useful_tools.md
@@ -109,6 +109,80 @@ python tools/analysis_tools/analyze_results.py \
--show-score-thr 0.3
```

## Fusing detection results from multiple models

`tools/analysis_tools/fuse_results.py` can fuse the detection results of multiple models using the Weighted Boxes Fusion (WBF) method. (Currently only the COCO format is supported.)

**Usage**

```shell
python tools/analysis_tools/fuse_results.py \
${PRED_RESULTS} \
[--annotation ${ANNOTATION}] \
[--weights ${WEIGHTS}] \
[--fusion-iou-thr ${FUSION_IOU_THR}] \
[--skip-box-thr ${SKIP_BOX_THR}] \
[--conf-type ${CONF_TYPE}] \
[--eval-single ${EVAL_SINGLE}] \
[--save-fusion-results ${SAVE_FUSION_RESULTS}] \
[--out-dir ${OUT_DIR}]
```

Description of all arguments:

- `pred-results`: Paths of the test result files of the multiple models. (Currently only the JSON format is supported.)
- `--annotation`: Path of the ground-truth annotation file.
- `--weights`: Fusion weight of each model. By default, every model is given a weight of 1.
- `--fusion-iou-thr`: IoU threshold for two boxes to be considered a match in the WBF algorithm. Default: `0.55`.
- `--skip-box-thr`: Score threshold used in the WBF algorithm; boxes whose confidence is lower than this value are excluded. Default: `0`.
- `--conf-type`: How to calculate the confidence of the fused boxes. There are four options:
  - `avg`: average value (default).
  - `max`: maximum value.
  - `box_and_model_avg`: box and model wise hybrid weighted average.
  - `absent_model_aware_avg`: weighted average that takes absent models into account.
- `--eval-single`: Whether to evaluate each single model. Default: `False`.
- `--save-fusion-results`: Whether to save the fusion results. Default: `False`.
- `--out-dir`: Output directory for the fusion results.

**Examples**:
Assume that you have already obtained 3 JSON result files from the corresponding models through `tools/test.py`, located at './faster-rcnn_r50-caffe_fpn_1x_coco.json', './retinanet_r50-caffe_fpn_1x_coco.json' and './cascade-rcnn_r50-caffe_fpn_1x_coco.json', and that the ground-truth annotation file is './annotation.json'.

1. Fuse the predictions of the three models and evaluate the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3
```

2. Evaluate each single model as well as the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3 \
--eval-single
```

3. Fuse the predictions of the three models and save the fused result

```shell
python tools/analysis_tools/fuse_results.py \
./faster-rcnn_r50-caffe_fpn_1x_coco.json \
./retinanet_r50-caffe_fpn_1x_coco.json \
./cascade-rcnn_r50-caffe_fpn_1x_coco.json \
--annotation ./annotation.json \
--weights 1 2 3 \
--save-fusion-results \
--out-dir outputs/fusion
```

## Visualization

### Visualize Datasets
3 changes: 2 additions & 1 deletion mmdet/models/utils/__init__.py
@@ -19,6 +19,7 @@
from .point_sample import (get_uncertain_point_coords_with_randomness,
                           get_uncertainty)
from .vlfuse_helper import BertEncoderLayer, VLFuse, permute_and_flatten
from .wbf import weighted_boxes_fusion

__all__ = [
    'gaussian_radius', 'gen_gaussian_target', 'make_divisible',
@@ -32,5 +33,5 @@
    'samplelist_boxtype2tensor', 'filter_gt_instances', 'rename_loss_dict',
    'reweight_loss_dict', 'relative_coordinate_maps', 'aligned_bilinear',
    'unfold_wo_center', 'imrenormalize', 'VLFuse', 'permute_and_flatten',
    'BertEncoderLayer', 'align_tensor'
    'BertEncoderLayer', 'align_tensor', 'weighted_boxes_fusion'
]
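
Since `weighted_boxes_fusion` is now exported from `mmdet.models.utils`, it can also be called directly on per-model predictions outside `demo/demo_multi_model.py`. A minimal sketch with toy values, assuming the function accepts plain lists of absolute `xyxy` boxes, scores and labels (the same per-model lists the demo script collects from `DetInferencer`) and returns fused arrays that support `.tolist()`:

```python
# Minimal sketch of calling the newly exported helper directly
# (keyword arguments follow demo/demo_multi_model.py; toy values only).
from mmdet.models.utils import weighted_boxes_fusion

# One entry per model: absolute xyxy boxes, scores and integer labels.
bboxes_list = [
    [[10, 10, 100, 100], [200, 200, 300, 300]],  # model A
    [[12, 8, 102, 98]],                          # model B
]
scores_list = [[0.90, 0.40], [0.80]]
labels_list = [[0, 1], [0]]

bboxes, scores, labels = weighted_boxes_fusion(
    bboxes_list,
    scores_list,
    labels_list,
    weights=[1, 2],      # per-model weights, same meaning as --weights
    iou_thr=0.55,        # overlap required for boxes to be merged
    skip_box_thr=0.0,    # drop boxes scoring below this before fusion
    conf_type='avg')     # confidence rule for the fused boxes

print(bboxes.tolist(), scores.tolist(), labels.tolist())
```

The keyword arguments mirror the demo script's command-line options (`--weights`, `--fusion-iou-thr`, `--skip-box-thr`, `--conf-type`).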