Support grounding dino (#10907)
Co-authored-by: YanxingLiu <[email protected]>
YanxingLiu and YanxingLiu committed Sep 18, 2023
1 parent d45bbda commit 073626f
Showing 17 changed files with 1,556 additions and 30 deletions.
52 changes: 52 additions & 0 deletions configs/grounding_dino/README.md
@@ -0,0 +1,52 @@
# Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection

[Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection](https://arxiv.org/abs/2303.05499)

<!-- [ALGORITHM] -->

## Abstract

In this paper, we present an open-set object detector, called Grounding DINO, by marrying Transformer-based detector DINO with grounded pre-training, which can detect arbitrary objects with human inputs such as category names or referring expressions. The key solution of open-set object detection is introducing language to a closed-set detector for open-set concept generalization. To effectively fuse language and vision modalities, we conceptually divide a closed-set detector into three phases and propose a tight fusion solution, which includes a feature enhancer, a language-guided query selection, and a cross-modality decoder for cross-modality fusion. While previous works mainly evaluate open-set object detection on novel categories, we propose to also perform evaluations on referring expression comprehension for objects specified with attributes. Grounding DINO performs remarkably well on all three settings, including benchmarks on COCO, LVIS, ODinW, and RefCOCO/+/g. Grounding DINO achieves a 52.5 AP on the COCO detection zero-shot transfer benchmark, i.e., without any training data from COCO. It sets a new record on the ODinW zero-shot benchmark with a mean 26.1 AP.

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/42299757/0ed51aeb-3d53-42d8-8563-f6d21364ac95"/>
</div>

## Installation

```shell
cd $MMDETROOT

# source installation
pip install -r requirements/multimodal.txt

# or mim installation
mim install mmdet[multimodal]
```

Download the converted Swin-T weights and run the image demo with a text prompt:

```shell
cd $MMDETROOT
wget https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth
python demo/image_demo.py \
demo/demo.jpg \
configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py \
--weights groundingdino_swint_ogc_mmdet-822d7e9d.pth \
--texts 'bench . car .'
```
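
The `--texts` value above lists category phrases separated by ` . `. As a minimal sketch of that prompt format, the snippet below builds such a string from a list of names and splits it back. `build_prompt` and `split_prompt` are hypothetical helpers written for illustration; they are not part of MMDetection.

```python
def build_prompt(categories):
    """Join category names into an 'a . b .' style prompt (hypothetical helper)."""
    cleaned = [c.strip() for c in categories if c.strip()]
    return ' . '.join(cleaned) + ' .' if cleaned else ''


def split_prompt(prompt):
    """Recover the individual phrases from a '.'-separated prompt."""
    return [p.strip() for p in prompt.split('.') if p.strip()]


prompt = build_prompt(['bench', 'car'])
print(prompt)                # the string passed to --texts in the demo above
print(split_prompt(prompt))  # the phrases the prompt is made of
```

The resulting string can be passed to `--texts` in the demo command.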

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/42299757/3a3bd6f1-e2ed-43d4-aa22-0bb07ee6f20b"/>
</div>

## Results and Models

| Model | Backbone | COCO mAP | Pre-Train Data | Config | Download |
| :--------------: | :------: | :------: | :----------------------------------------------: | :------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: |
| Grounding DINO-T | Swin-T | 48.5 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) |
| Grounding DINO-B | Swin-B | 56.9 | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth) |

Note:

1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being.
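
Before loading a downloaded checkpoint, it can help to check that the file is intact. The sketch below assumes that the 8-hex-digit suffix in the file name (for example `822d7e9d`) is the first eight hex digits of the file's SHA-256 digest, which is the usual OpenMMLab publishing convention. That is an assumption about these particular files, so a mismatch does not by itself prove corruption. The helper names are hypothetical.

```python
import hashlib
import re
from pathlib import Path


def filename_hash_suffix(path):
    """Return the 8-hex-digit suffix before '.pth', or None if absent."""
    match = re.search(r'-([0-9a-f]{8})\.pth$', Path(path).name)
    return match.group(1) if match else None


def sha256_prefix(path, length=8, chunk_size=1 << 20):
    """Hash the file in chunks and return the first `length` hex digits."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def matches_suffix(path):
    """True if the file's SHA-256 prefix equals the suffix in its name."""
    expected = filename_hash_suffix(path)
    return expected is not None and sha256_prefix(path) == expected
```

For example, `matches_suffix('groundingdino_swint_ogc_mmdet-822d7e9d.pth')` would be called after the `wget` step.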
16 changes: 16 additions & 0 deletions configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py
@@ -0,0 +1,16 @@
_base_ = [
    './grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py',
]

model = dict(
    type='GroundingDINO',
    backbone=dict(
        pretrain_img_size=384,
        embed_dims=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=12,
        drop_path_rate=0.3,
        patch_norm=True),
    neck=dict(in_channels=[256, 512, 1024]),
)
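
The Swin-B config above only lists the keys that differ from the Swin-T base. The following sketch illustrates how such an override could combine with the base: nested dicts are merged key by key, while lists and scalars are replaced. `merge_cfg` is a simplified stand-in written for this example, not MMEngine's actual implementation (which has additional rules such as `_delete_`). The sketch also checks that the neck's `in_channels` follow from the backbone width, assuming the standard Swin layout in which stage `i` has `embed_dims * 2**i` channels.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge `override` into a copy of `base` (simplified sketch)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Lists and scalars are replaced wholesale, not merged.
            merged[key] = copy.deepcopy(value)
    return merged


def swin_stage_channels(embed_dims, out_indices):
    """Channel width of Swin stage i, assuming embed_dims * 2**i."""
    return [embed_dims * 2**i for i in out_indices]


# Subset of the Swin-T base config and the Swin-B overrides from this commit.
base = dict(
    backbone=dict(type='SwinTransformer', embed_dims=96, depths=[2, 2, 6, 2],
                  window_size=7, out_indices=(1, 2, 3)),
    neck=dict(type='ChannelMapper', in_channels=[192, 384, 768],
              out_channels=256))
override = dict(
    backbone=dict(embed_dims=128, depths=[2, 2, 18, 2], window_size=12),
    neck=dict(in_channels=[256, 512, 1024]))

swin_b = merge_cfg(base, override)
print(swin_b['backbone']['type'])     # inherited from the base config
print(swin_b['backbone']['depths'])   # replaced by the override
print(swin_stage_channels(128, swin_b['backbone']['out_indices']))
```

Under these assumptions, the computed stage widths for `embed_dims=128` agree with the `in_channels` that the Swin-B config sets on the neck.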
127 changes: 127 additions & 0 deletions configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py
@@ -0,0 +1,127 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

lang_model_name = 'bert-base-uncased'

model = dict(
    type='GroundingDINO',
    num_queries=900,
    with_box_refine=True,
    as_two_stage=True,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_mask=False,
    ),
    language_model=dict(
        type='BertModel',
        name=lang_model_name,
        pad_to_max=False,
        use_sub_sentence_represent=True,
        special_tokens_list=['[CLS]', '[SEP]', '.', '?'],
        add_pooling_layer=True,
    ),
    backbone=dict(
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        with_cp=False,
        convert_weights=False),
    neck=dict(
        type='ChannelMapper',
        in_channels=[192, 384, 768],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        bias=True,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    encoder=dict(
        num_layers=6,
        # visual layer config
        layer_cfg=dict(
            self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)),
        # text layer config
        text_layer_cfg=dict(
            self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)),
        # fusion layer config
        fusion_layer_cfg=dict(
            v_dim=256,
            l_dim=256,
            embed_dim=1024,
            num_heads=4,
            init_values=1e-4),
    ),
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            # query self attention layer
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            # cross attention layer query to text
            cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            # cross attention layer query to image
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(
                embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)),
        post_norm_cfg=None),
    positional_encoding=dict(
        num_feats=128, normalize=True, offset=0.0, temperature=20),
    bbox_head=dict(
        type='GroundingDINOHead',
        num_classes=80,
        sync_cls_avg_factor=True,
        max_text_len=256,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),  # 2.0 in DeformDETR
        loss_bbox=dict(type='L1Loss', loss_weight=5.0)),
    dn_cfg=dict(  # TODO: Move to model.train_cfg ?
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(dynamic=True, num_groups=None,
                       num_dn_queries=100)),  # TODO: half num_dn_queries
    # training and testing settings
    train_cfg=None,
    test_cfg=dict(max_per_img=300))

test_pipeline = [
    dict(
        type='LoadImageFromFile', backend_args=None,
        imdecode_backend='pillow'),
    dict(
        type='FixScaleResize',
        scale=(800, 1333),
        keep_ratio=True,
        backend='pillow'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'text', 'custom_entities'))
]

val_dataloader = dict(
    dataset=dict(pipeline=test_pipeline, return_classes=True))
test_dataloader = val_dataloader
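
The `loss_cls` entry in the config above selects a sigmoid focal loss with `gamma=2.0` and `alpha=0.25`. As a rough illustration of what those two parameters do, here is a scalar sketch of the standard sigmoid focal loss formula for a single logit. It is not MMDetection's `FocalLoss` implementation, and it is not numerically hardened for very large logits.

```python
import math


def sigmoid_focal_loss(logit, target, gamma=2.0, alpha=0.25):
    """Focal loss for one logit and a binary target (1 = positive)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    # p_t is the probability assigned to the true class.
    p_t = p if target == 1 else 1.0 - p
    # alpha weights positives, (1 - alpha) weights negatives.
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma down-weights examples that are already easy.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# A confident correct prediction contributes far less than an uncertain one.
easy = sigmoid_focal_loss(4.0, 1)
hard = sigmoid_focal_loss(0.0, 1)
print(easy < hard)
```

With `gamma=0` and `alpha=0.5` the expression reduces to half the ordinary binary cross-entropy, which is a quick way to sanity-check the formula.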
40 changes: 40 additions & 0 deletions configs/grounding_dino/metafile.yml
@@ -0,0 +1,40 @@
Collections:
  - Name: Grounding DINO
    Metadata:
      Training Data: Objects365, GoldG, CC3M and COCO
      Training Techniques:
        - AdamW
        - Multi Scale Train
        - Gradient Clip
      Training Resources: A100 GPUs
      Architecture:
        - Swin Transformer
        - BERT
    Paper:
      URL: https://arxiv.org/abs/2303.05499
      Title: 'Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection
        '
    README: configs/grounding_dino/README.md
    Code:
      URL:
      Version: v3.0.0

Models:
  - Name: grounding_dino_swin-t_pretrain_obj365_goldg_cap4m
    In Collection: Grounding DINO
    Config: configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 48.5
    Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth
  - Name: grounding_dino_swin-b_pretrain_mixeddata
    In Collection: Grounding DINO
    Config: configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py
    Results:
      - Task: Object Detection
        Dataset: COCO
        Metrics:
          box AP: 56.9
    Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -26,6 +26,7 @@
 from .ga_retina_head import GARetinaHead
 from .ga_rpn_head import GARPNHead
 from .gfl_head import GFLHead
+from .grounding_dino_head import GroundingDINOHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
 from .lad_head import LADHead
 from .ld_head import LDHead
@@ -67,5 +68,5 @@
     'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead',
     'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead',
     'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead',
-    'ATSSVLFusionHead', 'DABDETRHead', 'DDQDETRHead'
+    'ATSSVLFusionHead', 'DABDETRHead', 'DDQDETRHead', 'GroundingDINOHead'
 ]