diff --git a/configs/grounding_dino/README.md b/configs/grounding_dino/README.md new file mode 100644 index 00000000000..4addc4f4d6d --- /dev/null +++ b/configs/grounding_dino/README.md @@ -0,0 +1,52 @@ +# Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection + +[GLIP: Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857) + + + +## Abstract + +In this paper, we present an open-set object detector, called Grounding DINO, by marrying Transformer-based detector DINO with grounded pre-training, which can detect arbitrary objects with human inputs such as category names or referring expressions. The key solution of open-set object detection is introducing language to a closed-set detector for open-set concept generalization. To effectively fuse language and vision modalities, we conceptually divide a closed-set detector into three phases and propose a tight fusion solution, which includes a feature enhancer, a language-guided query selection, and a cross-modality decoder for cross-modality fusion. While previous works mainly evaluate open-set object detection on novel categories, we propose to also perform evaluations on referring expression comprehension for objects specified with attributes. Grounding DINO performs remarkably well on all three settings, including benchmarks on COCO, LVIS, ODinW, and RefCOCO/+/g. Grounding DINO achieves a 52.5 AP on the COCO detection zero-shot transfer benchmark, i.e., without any training data from COCO. It sets a new record on the ODinW zero-shot benchmark with a mean 26.1 AP. + +
+ +
+ +## Installation + +```shell +cd $MMDETROOT + +# source installation +pip install -r requirements/multimodal.txt + +# or mim installation +mim install mmdet[multimodal] +``` + +``` +cd $MMDETROOT + +wget https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth + +python demo/image_demo.py \ + demo/demo.jpg \ + configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py \ + --weights groundingdino_swint_ogc_mmdet-822d7e9d.pth \ + --texts 'bench . car .' +``` + +
+ +
+ +## Results and Models + +| Model | backbone | COCO mAP | Pre-Train Data | Config | Download | +| :--------------: | :------: | :------: | :----------------------------------------------: | :------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------: | +| Grounding DINO-T | Swin-T | 48.5 | O365,GoldG,Cap4M | [config](grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth) | +| Grounding DINO-B | Swin-B | 56.9 | COCO,O365,GoldG,Cap4M,OpenImage,ODinW-35,RefCOCO | [config](grounding_dino_swin-b_pretrain_mixeddata.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth) | + +Note: + +1. The weights corresponding to the zero-shot model are adopted from the official weights and converted using the [script](../../tools/model_converters/groundingdino_to_mmdet.py). We have not retrained the model for the time being. diff --git a/configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py b/configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py new file mode 100644 index 00000000000..92f327fef83 --- /dev/null +++ b/configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py @@ -0,0 +1,16 @@ +_base_ = [ + './grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py', +] + +model = dict( + type='GroundingDINO', + backbone=dict( + pretrain_img_size=384, + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + drop_path_rate=0.3, + patch_norm=True), + neck=dict(in_channels=[256, 512, 1024]), +) diff --git a/configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py b/configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py new file mode 100644 index 00000000000..41069e29035 --- /dev/null +++ b/configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py @@ -0,0 +1,127 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +lang_model_name = 'bert-base-uncased' + +model = dict( + type='GroundingDINO', + num_queries=900, + with_box_refine=True, + as_two_stage=True, + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=False, + ), + language_model=dict( + type='BertModel', + name=lang_model_name, + pad_to_max=False, + use_sub_sentence_represent=True, + special_tokens_list=['[CLS]', '[SEP]', '.', '?'], + add_pooling_layer=True, + ), + backbone=dict( + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=False, + convert_weights=False), + neck=dict( + type='ChannelMapper', + in_channels=[192, 384, 768], + kernel_size=1, + out_channels=256, + act_cfg=None, + bias=True, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + encoder=dict( + num_layers=6, + # visual layer config + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_levels=4, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + # text layer config + text_layer_cfg=dict( + self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.0)), + # fusion layer config + fusion_layer_cfg=dict( + v_dim=256, + l_dim=256, + embed_dim=1024, + num_heads=4, + init_values=1e-4), + ), + decoder=dict( + num_layers=6, + return_intermediate=True, + layer_cfg=dict( + # query self attention layer + self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to text + cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to image + cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.0)), + post_norm_cfg=None), + positional_encoding=dict( + num_feats=128, normalize=True, offset=0.0, temperature=20), + bbox_head=dict( + type='GroundingDINOHead', + num_classes=80, + sync_cls_avg_factor=True, + max_text_len=256, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), # 2.0 in DeformDETR + loss_bbox=dict(type='L1Loss', loss_weight=5.0)), + dn_cfg=dict( # TODO: Move to model.train_cfg ? + label_noise_scale=0.5, + box_noise_scale=1.0, # 0.4 for DN-DETR + group_cfg=dict(dynamic=True, num_groups=None, + num_dn_queries=100)), # TODO: half num_dn_queries + # training and testing settings + train_cfg=None, + test_cfg=dict(max_per_img=300)) + +test_pipeline = [ + dict( + type='LoadImageFromFile', backend_args=None, + imdecode_backend='pillow'), + dict( + type='FixScaleResize', + scale=(800, 1333), + keep_ratio=True, + backend='pillow'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'text', 'custom_entities')) +] + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader diff --git a/configs/grounding_dino/metafile.yml b/configs/grounding_dino/metafile.yml new file mode 100644 index 00000000000..86a0858d690 --- /dev/null +++ b/configs/grounding_dino/metafile.yml @@ -0,0 +1,40 @@ +Collections: + - Name: Grounding DINO + Metadata: + Training Data: Objects365, GoldG, CC3M and COCO + Training Techniques: + - AdamW + - Multi Scale Train + - Gradient Clip + Training Resources: A100 GPUs + Architecture: + - Swin Transformer + - BERT + Paper: + URL: https://arxiv.org/abs/2303.05499 + Title: 'Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection +' + README: configs/grounding_dino/README.md + Code: + URL: + Version: v3.0.0 + +Models: + - Name: grounding_dino_swin-t_pretrain_obj365_goldg_cap4m + In Collection: Grounding DINO + Config: configs/grounding_dino/grounding_dino_swin-t_pretrain_obj365_goldg_cap4m.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 48.5 + Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth + - Name: grounding_dino_swin-b_pretrain_mixeddata + In Collection: GLIPGrounding DINO + Config: configs/grounding_dino/grounding_dino_swin-b_pretrain_mixeddata.py + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 56.9 + Weights: https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swinb_cogcoor_mmdet-55949c9c.pth diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 2143d93d854..c9b55ec2a42 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -26,6 +26,7 @@ from .ga_retina_head import GARetinaHead from .ga_rpn_head import GARPNHead from .gfl_head import GFLHead +from .grounding_dino_head import GroundingDINOHead from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead from .lad_head import LADHead from .ld_head import LDHead @@ -67,5 +68,5 @@ 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead', 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead', 'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead', - 'ATSSVLFusionHead', 'DABDETRHead', 'DDQDETRHead' + 'ATSSVLFusionHead', 'DABDETRHead', 'DDQDETRHead', 'GroundingDINOHead' ] diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py new file mode 100644 index 00000000000..d3ca2baf088 --- /dev/null +++ b/mmdet/models/dense_heads/grounding_dino_head.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import Linear +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy +from mmdet.utils import InstanceList +from ..layers import inverse_sigmoid +from .atss_vlfusion_head import convert_grounding_to_cls_scores +from .dino_head import DINOHead + + +class ContrastiveEmbed(nn.Module): + """text visual ContrastiveEmbed layer. + + Args: + max_text_len (int, optional): Maximum length of text. + """ + + def __init__(self, max_text_len=256): + super().__init__() + self.max_text_len = max_text_len + + def forward(self, visual_feat: Tensor, text_feat: Tensor, + text_token_mask: Tensor) -> Tensor: + """Forward function. + + Args: + visual_feat (Tensor): Visual features. + text_feat (Tensor): Text features. + text_token_mask (Tensor): A mask used for text feats. + + Returns: + Tensor: Classification score. + """ + res = visual_feat @ text_feat.transpose(-1, -2) + res.masked_fill_(~text_token_mask[:, None, :], float('-inf')) + + new_res = torch.full((*res.shape[:-1], self.max_text_len), + float('-inf'), + device=res.device) + new_res[..., :res.shape[-1]] = res + + return new_res + + +@MODELS.register_module() +class GroundingDINOHead(DINOHead): + """Head of the Grounding DINO: Marrying DINO with Grounded Pre-Training for + Open-Set Object Detection. + + Args: + max_text_len (int, optional): Maximum length of text. + """ + + def __init__(self, max_text_len=256, **kwargs): + + self.max_text_len = max_text_len + super().__init__(**kwargs) + + def _init_layers(self) -> None: + """Initialize classification branch and regression branch of head.""" + fc_cls = ContrastiveEmbed(self.max_text_len) + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, 4)) + reg_branch = nn.Sequential(*reg_branch) + + # NOTE: due to the fc_cls is a contrastive embedding and don't + # have any trainable parameters,we do not need to copy it. + if self.share_pred_layer: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(self.num_pred_layer)]) + else: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList([ + copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer) + ]) + + def forward( + self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + ) -> Tuple[Tensor]: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (List[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + memory_text (Tensor): Memory text. It has shape (bs, len_text, + text_embed_dims). + text_token_mask (Tensor): Text token mask. It has shape (bs, + len_text). + + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + """ + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + # NOTE The last reference will not be used. + hidden_state = hidden_states[layer_id] + outputs_class = self.cls_branches[layer_id](hidden_state, + memory_text, + text_token_mask) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) + if reference.shape[-1] == 4: + # When `layer` is 0 and `as_two_stage` of the detector + # is `True`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `True`. + tmp_reg_preds += reference + else: + # When `layer` is 0 and `as_two_stage` of the detector + # is `False`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `False`. + assert reference.shape[-1] == 2 + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) + + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def predict(self, + hidden_states: Tensor, + references: List[Tensor], + memory_text: Tensor, + text_token_mask: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (List[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + memory_text (Tensor): Memory text. It has shape (bs, len_text, + text_embed_dims). + text_token_mask (Tensor): Text token mask. It has shape (bs, + len_text). + batch_data_samples (SampleList): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. + + Returns: + InstanceList: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + batch_token_positive_maps = [ + data_samples.token_positive_map + for data_samples in batch_data_samples + ] + + outs = self(hidden_states, references, memory_text, text_token_mask) + + predictions = self.predict_by_feat( + *outs, + batch_img_metas=batch_img_metas, + batch_token_positive_maps=batch_token_positive_maps, + rescale=rescale) + return predictions + + def predict_by_feat(self, + all_layers_cls_scores: Tensor, + all_layers_bbox_preds: Tensor, + batch_img_metas: List[Dict], + batch_token_positive_maps: Optional[List[dict]] = None, + rescale: bool = False) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Args: + all_layers_cls_scores (Tensor): Classification scores of all + decoder layers, has shape (num_decoder_layers, bs, num_queries, + cls_out_channels). + all_layers_bbox_preds (Tensor): Regression outputs of all decoder + layers. Each is a 4D-tensor with normalized coordinate format + (cx, cy, w, h) and shape (num_decoder_layers, bs, num_queries, + 4) with the last dimension arranged as (cx, cy, w, h). + batch_img_metas (List[Dict]): _description_ + batch_token_positive_maps (list[dict], Optional): Batch token + positive map. Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + cls_scores = all_layers_cls_scores[-1] + bbox_preds = all_layers_bbox_preds[-1] + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_meta = batch_img_metas[img_id] + token_positive_maps = batch_token_positive_maps[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + token_positive_maps, + img_meta, rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score: Tensor, + bbox_pred: Tensor, + token_positive_maps: dict, + img_meta: dict, + rescale: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (cx, cy, w, h) and + shape [num_queries, 4]. + token_positive_maps (dict): Token positive map. + img_meta (dict): Image meta info. + rescale (bool, optional): If True, return boxes in original image + space. Default True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_score) == len(bbox_pred) # num_queries + max_per_img = self.test_cfg.get('max_per_img', len(cls_score)) + img_shape = img_meta['img_shape'] + + cls_score = convert_grounding_to_cls_scores( + logits=cls_score.sigmoid()[None], + positive_maps=[token_positive_maps])[0] + scores, indexes = cls_score.view(-1).topk(max_per_img) + num_classes = cls_score.shape[-1] + det_labels = indexes % num_classes + bbox_index = indexes // num_classes + bbox_pred = bbox_pred[bbox_index] + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + if rescale: + assert img_meta.get('scale_factor') is not None + det_bboxes /= det_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + results = InstanceData() + results.bboxes = det_bboxes + results.scores = scores + results.labels = det_labels + return results diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index bc1ff257da4..e5a06d2813c 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -25,6 +25,7 @@ from .gfl import GFL from .glip import GLIP from .grid_rcnn import GridRCNN +from .grounding_dino import GroundingDINO from .htc import HybridTaskCascade from .kd_one_stage import KnowledgeDistillationSingleStageDetector from .lad import LAD @@ -70,5 +71,5 @@ 'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher', 'RTMDet', 'Detectron2Wrapper', 'CrowdDet', 'CondInst', 'BoxInst', 'DetectionTransformer', 'ConditionalDETR', 'DINO', 'DABDETR', 'GLIP', - 'DDQDETR' + 'DDQDETR', 'GroundingDINO' ] diff --git a/mmdet/models/detectors/dino.py b/mmdet/models/detectors/dino.py index a4385462aff..ade47f531d2 100644 --- a/mmdet/models/detectors/dino.py +++ b/mmdet/models/detectors/dino.py @@ -221,7 +221,8 @@ def forward_decoder(self, spatial_shapes: Tensor, level_start_index: Tensor, valid_ratios: Tensor, - dn_mask: Optional[Tensor] = None) -> Dict: + dn_mask: Optional[Tensor] = None, + **kwargs) -> Dict: """Forward with Transformer decoder. The forward procedure of the transformer is defined as: @@ -270,7 +271,8 @@ def forward_decoder(self, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, - reg_branches=self.bbox_head.reg_branches) + reg_branches=self.bbox_head.reg_branches, + **kwargs) if len(query) == self.num_queries: # NOTE: This is to make sure label_embeding can be involved to diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py new file mode 100644 index 00000000000..b2495b91cd3 --- /dev/null +++ b/mmdet/models/detectors/grounding_dino.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList +from ..layers import SinePositionalEncoding +from ..layers.transformer.grounding_dino_layers import ( + GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder) +from .dino import DINO +from .glip import (create_positive_map, create_positive_map_label_to_token, + run_ner) + + +@MODELS.register_module() +class GroundingDINO(DINO): + """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre- + Training for Open-Set Object Detection. + + `_ + + Code is modified from the `official github repo + `_. + """ + + def __init__(self, language_model, *args, **kwargs) -> None: + + self.language_model_cfg = language_model + self._special_tokens = '. ' + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = GroundingDinoTransformerEncoder(**self.encoder) + self.decoder = GroundingDinoTransformerDecoder(**self.decoder) + self.embed_dims = self.encoder.embed_dims + self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims) + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + f'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + + # text modules + self.language_model = MODELS.build(self.language_model_cfg) + self.text_feat_map = nn.Linear( + self.language_model.language_backbone.body.language_dim, + self.embed_dims, + bias=True) + nn.init.constant_(self.text_feat_map.bias.data, 0) + nn.init.xavier_uniform_(self.text_feat_map.weight.data) + + def get_tokens_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, list]: + """Get the tokens positive and prompts for the caption.""" + if isinstance(original_caption, (list, tuple)) or custom_entities: + if custom_entities and isinstance(original_caption, str): + if original_caption.endswith(self._special_tokens): + original_caption = original_caption.replace( + self._special_tokens, '') + original_caption = original_caption.split(self._special_tokens) + original_caption = list( + filter(lambda x: len(x) > 0, original_caption)) + + caption_string = '' + tokens_positive = [] + for idx, word in enumerate(original_caption): + tokens_positive.append( + [[len(caption_string), + len(caption_string) + len(word)]]) + caption_string += word + caption_string += self._special_tokens + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. + tokenized = self.language_model.tokenizer( + [caption_string], + padding='max_length' + if self.language_model.pad_to_max else 'longest', + return_tensors='pt') + entities = original_caption + else: + if original_caption.endswith(self._special_tokens): + original_caption = original_caption.replace( + self._special_tokens, '') + # NOTE: Tokenizer in Grounding DINO is different from + # that in GLIP. The tokenizer in GLIP will pad the + # caption_string to max_length, while the tokenizer + # in Grounding DINO will not. + tokenized = self.language_model.tokenizer( + [original_caption], + padding='max_length' + if self.language_model.pad_to_max else 'longest', + return_tensors='pt') + tokens_positive, noun_phrases = run_ner(original_caption) + entities = noun_phrases + caption_string = original_caption + + return tokenized, caption_string, tokens_positive, entities + + def get_positive_map(self, tokenized, tokens_positive): + positive_map = create_positive_map(tokenized, tokens_positive) + positive_map_label_to_token = create_positive_map_label_to_token( + positive_map, plus=1) + return positive_map_label_to_token, positive_map + + def get_tokens_positive_and_prompts( + self, + original_caption: Union[str, list, tuple], + custom_entities: bool = False) -> Tuple[dict, str, Tensor, list]: + """Get the tokens positive and prompts for the caption.""" + tokenized, caption_string, tokens_positive, entities = \ + self.get_tokens_and_prompts( + original_caption, custom_entities) + positive_map_label_to_token, positive_map = self.get_positive_map( + tokenized, tokens_positive) + return positive_map_label_to_token, caption_string, \ + positive_map, entities + + def forward_transformer( + self, + img_feats: Tuple[Tensor], + text_dict: Dict, + batch_data_samples: OptSampleList = None, + ) -> Dict: + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + encoder_outputs_dict = self.forward_encoder( + **encoder_inputs_dict, text_dict=text_dict) + + tmp_dec_in, head_inputs_dict = self.pre_decoder( + **encoder_outputs_dict, batch_data_samples=batch_data_samples) + decoder_inputs_dict.update(tmp_dec_in) + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + head_inputs_dict.update(decoder_outputs_dict) + return head_inputs_dict + + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + text_dict: Dict) -> Dict: + text_token_mask = text_dict['text_token_mask'] + memory, memory_text = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + # for text encoder + memory_text=text_dict['embedded'], + text_attention_mask=~text_token_mask, + position_ids=text_dict['position_ids'], + text_self_attention_masks=text_dict['masks']) + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes, + memory_text=memory_text, + text_token_mask=text_token_mask) + return encoder_outputs_dict + + def pre_decoder( + self, + memory: Tensor, + memory_mask: Tensor, + spatial_shapes: Tensor, + memory_text: Tensor, + text_token_mask: Tensor, + batch_data_samples: OptSampleList = None, + ) -> Tuple[Dict]: + bs, _, c = memory.shape + + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers](output_memory, memory_text, + text_token_mask) + cls_out_features = self.bbox_head.cls_branches[ + self.decoder.num_layers].max_text_len + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + # NOTE The DINO selects top-k proposals according to scores of + # multi-class classification, while DeformDETR, where the input + # is `enc_outputs_class[..., 0]` selects according to scores of + # binary classification. + topk_indices = torch.topk( + enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1] + + topk_score = torch.gather( + enc_outputs_class, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords = topk_coords_unact.sigmoid() + topk_coords_unact = topk_coords_unact.detach() + + query = self.query_embedding.weight[:, None, :] + query = query.repeat(1, bs, 1).transpose(0, 1) + if self.training: + dn_label_query, dn_bbox_query, dn_mask, dn_meta = \ + self.dn_query_generator(batch_data_samples) + query = torch.cat([dn_label_query, query], dim=1) + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], + dim=1) + else: + reference_points = topk_coords_unact + dn_mask, dn_meta = None, None + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, + memory=memory, + reference_points=reference_points, + dn_mask=dn_mask, + memory_text=memory_text, + text_attention_mask=~text_token_mask, + ) + # NOTE DINO calculates encoder losses on scores and coordinates + # of selected top-k encoder queries, while DeformDETR is of all + # encoder queries. + head_inputs_dict = dict( + enc_outputs_class=topk_score, + enc_outputs_coord=topk_coords, + dn_meta=dn_meta) if self.training else dict() + # append text_feats to head_inputs_dict + head_inputs_dict['memory_text'] = memory_text + head_inputs_dict['text_token_mask'] = text_token_mask + return decoder_inputs_dict, head_inputs_dict + + def predict(self, batch_inputs, batch_data_samples, rescale: bool = True): + text_prompts = [ + data_samples.text for data_samples in batch_data_samples + ] + if 'custom_entities' in batch_data_samples[0]: + # Assuming that the `custom_entities` flag + # inside a batch is always the same. For single image inference + custom_entities = batch_data_samples[0].custom_entities + else: + custom_entities = False + if len(text_prompts) == 1: + # All the text prompts are the same, + # so there is no need to calculate them multiple times. + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompts[0], + custom_entities) + ] * len(batch_inputs) + else: + _positive_maps_and_prompts = [ + self.get_tokens_positive_and_prompts(text_prompt, + custom_entities) + for text_prompt in text_prompts + ] + token_positive_maps, text_prompts, _, entities = zip( + *_positive_maps_and_prompts) + # extract text feats + text_dict = self.language_model(list(text_prompts)) + # text feature map layer + if self.text_feat_map is not None: + text_dict['embedded'] = self.text_feat_map(text_dict['embedded']) + + for i, data_samples in enumerate(batch_data_samples): + data_samples.token_positive_map = token_positive_maps[i] + + # image feature extraction + visual_feats = self.extract_feat(batch_inputs) + + head_inputs_dict = self.forward_transformer(visual_feats, text_dict, + batch_data_samples) + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + for data_sample, pred_instances, entity in zip(batch_data_samples, + results_list, entities): + if len(pred_instances) > 0: + label_names = [] + for labels in pred_instances.labels: + if labels >= len(entity): + warnings.warn( + 'The unexpected output indicates an issue with ' + 'named entity recognition. You can try ' + 'setting custom_entities=True and running ' + 'again to see if it helps.') + label_names.append('unobject') + else: + label_names.append(entity[labels]) + # for visualization + pred_instances.label_names = label_names + data_sample.pred_instances = pred_instances + return batch_data_samples diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py index 86a4dc8d5d1..3a911bbc2f4 100644 --- a/mmdet/models/language_models/bert.py +++ b/mmdet/models/language_models/bert.py @@ -16,20 +16,72 @@ from mmdet.registry import MODELS +def generate_masks_with_special_tokens_and_transfer_map( + tokenized, special_tokens_list): + """Generate attention mask between each pair of special tokens + Args: + input_ids (torch.Tensor): input ids. Shape: [bs, num_token] + special_tokens_mask (list): special tokens mask. + Returns: + torch.Tensor: attention mask between each special tokens. + """ + input_ids = tokenized['input_ids'] + bs, num_token = input_ids.shape + # special_tokens_mask: + # bs, num_token. 1 for special tokens. 0 for normal tokens + special_tokens_mask = torch.zeros((bs, num_token), + device=input_ids.device).bool() + + for special_token in special_tokens_list: + special_tokens_mask |= input_ids == special_token + + # idxs: each row is a list of indices of special tokens + idxs = torch.nonzero(special_tokens_mask) + + # generate attention mask and positional ids + attention_mask = ( + torch.eye(num_token, + device=input_ids.device).bool().unsqueeze(0).repeat( + bs, 1, 1)) + position_ids = torch.zeros((bs, num_token), device=input_ids.device) + previous_col = 0 + for i in range(idxs.shape[0]): + row, col = idxs[i] + if (col == 0) or (col == num_token - 1): + attention_mask[row, col, col] = True + position_ids[row, col] = 0 + else: + attention_mask[row, previous_col + 1:col + 1, + previous_col + 1:col + 1] = True + position_ids[row, previous_col + 1:col + 1] = torch.arange( + 0, col - previous_col, device=input_ids.device) + previous_col = col + + return attention_mask, position_ids.to(torch.long) + + @MODELS.register_module() class BertModel(BaseModel): """BERT model for language embedding only encoder. Args: - name (str): name of the pretrained BERT model from HuggingFace. - Defaults to bert-base-uncased. - max_tokens (int): maximum number of tokens to be used for BERT. - Defaults to 256. - pad_to_max (bool): whether to pad the tokens to max_tokens. + name (str, optional): name of the pretrained BERT model from + HuggingFace. Defaults to bert-base-uncased. + max_tokens (int, optional): maximum number of tokens to be + used for BERT. Defaults to 256. + pad_to_max (bool, optional): whether to pad the tokens to max_tokens. Defaults to True. - num_layers_of_embedded (int): number of layers of the embedded model. - Defaults to 1. - use_checkpoint (bool): whether to use gradient checkpointing. + use_sub_sentence_represent (bool, optional): whether to use sub + sentence represent introduced in `Grounding DINO + `. Defaults to False. + special_tokens_list (list, optional): special tokens used to split + subsentence. It cannot be None when `use_sub_sentence_represent` + is True. Defaults to None. + add_pooling_layer (bool, optional): whether to adding pooling + layer in bert encoder. Defaults to False. + num_layers_of_embedded (int, optional): number of layers of + the embedded model. Defaults to 1. + use_checkpoint (bool, optional): whether to use gradient checkpointing. Defaults to False. """ @@ -37,9 +89,13 @@ def __init__(self, name: str = 'bert-base-uncased', max_tokens: int = 256, pad_to_max: bool = True, + use_sub_sentence_represent: bool = False, + special_tokens_list: list = None, + add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False, **kwargs) -> None: + super().__init__(**kwargs) self.max_tokens = max_tokens self.pad_to_max = pad_to_max @@ -54,9 +110,19 @@ def __init__(self, OrderedDict([('body', BertEncoder( name, + add_pooling_layer=add_pooling_layer, num_layers_of_embedded=num_layers_of_embedded, use_checkpoint=use_checkpoint))])) + self.use_sub_sentence_represent = use_sub_sentence_represent + if self.use_sub_sentence_represent: + assert special_tokens_list is not None, \ + 'special_tokens should not be None \ + if use_sub_sentence_represent is True' + + self.special_tokens = self.tokenizer.convert_tokens_to_ids( + special_tokens_list) + def forward(self, captions: Sequence[str], **kwargs) -> dict: """Forward function.""" device = next(self.language_backbone.parameters()).device @@ -67,12 +133,29 @@ def forward(self, captions: Sequence[str], **kwargs) -> dict: return_special_tokens_mask=True, return_tensors='pt', truncation=True).to(device) + input_ids = tokenized.input_ids + if self.use_sub_sentence_represent: + attention_mask, position_ids = \ + generate_masks_with_special_tokens_and_transfer_map( + tokenized, self.special_tokens) + token_type_ids = tokenized['token_type_ids'] + + else: + attention_mask = tokenized.attention_mask + position_ids = None + token_type_ids = None tokenizer_input = { - 'input_ids': tokenized.input_ids, - 'attention_mask': tokenized.attention_mask + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'token_type_ids': token_type_ids } language_dict_features = self.language_backbone(tokenizer_input) + if self.use_sub_sentence_represent: + language_dict_features['position_ids'] = position_ids + language_dict_features[ + 'text_token_mask'] = tokenized.attention_mask.bool() return language_dict_features @@ -82,6 +165,7 @@ class BertEncoder(nn.Module): Args: name (str): name of the pretrained BERT model from HuggingFace. Defaults to bert-base-uncased. + add_pooling_layer (bool): whether to add a pooling layer. num_layers_of_embedded (int): number of layers of the embedded model. Defaults to 1. use_checkpoint (bool): whether to use gradient checkpointing. @@ -90,6 +174,7 @@ class BertEncoder(nn.Module): def __init__(self, name: str, + add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False): super().__init__() @@ -101,7 +186,7 @@ def __init__(self, config.gradient_checkpointing = use_checkpoint # only encoder self.model = HFBertModel.from_pretrained( - name, add_pooling_layer=False, config=config) + name, add_pooling_layer=add_pooling_layer, config=config) self.language_dim = config.hidden_size self.num_layers_of_embedded = num_layers_of_embedded @@ -111,6 +196,8 @@ def forward(self, x) -> dict: outputs = self.model( input_ids=x['input_ids'], attention_mask=mask, + position_ids=x['position_ids'], + token_type_ids=x['token_type_ids'], output_hidden_states=True, ) @@ -120,7 +207,10 @@ def forward(self, x) -> dict: 1).mean(1) # language embedding has shape [len(phrase), seq_len, language_dim] features = features / self.num_layers_of_embedded - embedded = features * mask.unsqueeze(-1).float() + if mask.dim() == 2: + embedded = features * mask.unsqueeze(-1).float() + else: + embedded = features results = { 'embedded': embedded, diff --git a/mmdet/models/layers/transformer/__init__.py b/mmdet/models/layers/transformer/__init__.py index 3465ef3d1a7..839d9364126 100644 --- a/mmdet/models/layers/transformer/__init__.py +++ b/mmdet/models/layers/transformer/__init__.py @@ -12,6 +12,9 @@ from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, DetrTransformerEncoder, DetrTransformerEncoderLayer) from .dino_layers import CdnQueryGenerator, DinoTransformerDecoder +from .grounding_dino_layers import (GroundingDinoTransformerDecoder, + GroundingDinoTransformerDecoderLayer, + GroundingDinoTransformerEncoder) from .mask2former_layers import (Mask2FormerTransformerDecoder, Mask2FormerTransformerDecoderLayer, Mask2FormerTransformerEncoder) @@ -32,5 +35,7 @@ 'DDQTransformerDecoder', 'ConditionalDetrTransformerDecoder', 'ConditionalDetrTransformerDecoderLayer', 'DinoTransformerDecoder', 'CdnQueryGenerator', 'Mask2FormerTransformerEncoder', - 'Mask2FormerTransformerDecoderLayer', 'Mask2FormerTransformerDecoder' + 'Mask2FormerTransformerDecoderLayer', 'Mask2FormerTransformerDecoder', + 'GroundingDinoTransformerDecoderLayer', 'GroundingDinoTransformerEncoder', + 'GroundingDinoTransformerDecoder' ] diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py new file mode 100644 index 00000000000..04de47288b3 --- /dev/null +++ b/mmdet/models/layers/transformer/grounding_dino_layers.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import ModuleList +from torch import Tensor + +from mmdet.models.utils.vlfuse_helper import SingleScaleBiAttentionBlock +from mmdet.utils import ConfigType, OptConfigType +from .deformable_detr_layers import (DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer) +from .detr_layers import DetrTransformerEncoderLayer +from .dino_layers import DinoTransformerDecoder +from .utils import MLP, get_text_sine_pos_embed + + +class GroundingDinoTransformerDecoderLayer( + DeformableDetrTransformerDecoderLayer): + + def __init__(self, + cross_attn_text_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + **kwargs) -> None: + """Decoder layer of Deformable DETR.""" + self.cross_attn_text_cfg = cross_attn_text_cfg + if 'batch_first' not in self.cross_attn_text_cfg: + self.cross_attn_text_cfg['batch_first'] = True + super().__init__(**kwargs) + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn_text = MultiheadAttention(**self.cross_attn_text_cfg) + self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(4) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + memory_text: Tensor = None, + text_attention_mask: Tensor = None, + **kwargs) -> Tensor: + """Implements decoder layer in Grounding DINO transformer. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (bs, num_value). + Defaults to None. + memory_text (Tensor): Memory text. It has shape (bs, len_text, + text_embed_dims). + text_token_mask (Tensor): Text token mask. It has shape (bs, + len_text). + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + # self attention + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + query = self.norms[0](query) + # cross attention between query and text + query = self.cross_attn_text( + query=query, + query_pos=query_pos, + key=memory_text, + value=memory_text, + key_padding_mask=text_attention_mask) + query = self.norms[1](query) + # cross attention between query and image + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[2](query) + query = self.ffn(query) + query = self.norms[3](query) + + return query + + +class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder): + + def __init__(self, text_layer_cfg: ConfigType, + fusion_layer_cfg: ConfigType, **kwargs) -> None: + self.text_layer_cfg = text_layer_cfg + self.fusion_layer_cfg = fusion_layer_cfg + super().__init__(**kwargs) + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.text_layers = ModuleList([ + DetrTransformerEncoderLayer(**self.text_layer_cfg) + for _ in range(self.num_layers) + ]) + self.fusion_layers = ModuleList([ + SingleScaleBiAttentionBlock(**self.fusion_layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, + query: Tensor, + query_pos: Tensor, + key_padding_mask: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + memory_text: Tensor = None, + text_attention_mask: Tensor = None, + pos_text: Tensor = None, + text_self_attention_masks: Tensor = None, + position_ids: Tensor = None): + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + memory_text (Tensor, optional): Memory text. It has shape (bs, + len_text, text_embed_dims). + text_attention_mask (Tensor, optional): Text token mask. It has + shape (bs,len_text). + pos_text (Tensor, optional): The positional encoding for text. + Defaults to None. + text_self_attention_masks (Tensor, optional): Text self attention + mask. Defaults to None. + position_ids (Tensor, optional): Text position ids. + Defaults to None. + """ + output = query + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + if self.text_layers: + # generate pos_text + bs, n_text, _ = memory_text.shape + if pos_text is None and position_ids is None: + pos_text = ( + torch.arange(n_text, + device=memory_text.device).float().unsqueeze( + 0).unsqueeze(-1).repeat(bs, 1, 1)) + pos_text = get_text_sine_pos_embed( + pos_text, num_pos_feats=256, exchange_xy=False) + if position_ids is not None: + pos_text = get_text_sine_pos_embed( + position_ids[..., None], + num_pos_feats=256, + exchange_xy=False) + + # main process + for layer_id, layer in enumerate(self.layers): + if self.fusion_layers: + output, memory_text = self.fusion_layers[layer_id]( + visual_feature=output, + lang_feature=memory_text, + attention_mask_v=key_padding_mask, + attention_mask_l=text_attention_mask, + ) + if self.text_layers: + text_num_heads = self.text_layers[ + layer_id].self_attn_cfg.num_heads + memory_text = self.text_layers[layer_id]( + query=memory_text, + query_pos=(pos_text if pos_text is not None else None), + attn_mask=~text_self_attention_masks.repeat( + text_num_heads, 1, 1), # note we use ~ for mask here + key_padding_mask=None, + ) + output = layer( + query=output, + query_pos=query_pos, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=key_padding_mask) + return output, memory_text + + +class GroundingDinoTransformerDecoder(DinoTransformerDecoder): + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + GroundingDinoTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + self.ref_point_head = MLP(self.embed_dims * 2, self.embed_dims, + self.embed_dims, 2) + self.norm = nn.LayerNorm(self.embed_dims) diff --git a/mmdet/models/layers/transformer/utils.py b/mmdet/models/layers/transformer/utils.py index 3ba8a824a24..6e43a172ca7 100644 --- a/mmdet/models/layers/transformer/utils.py +++ b/mmdet/models/layers/transformer/utils.py @@ -874,3 +874,42 @@ def forward(self, param_feature: Tensor, input_feature: Tensor) -> Tensor: features = self.activation(features) return features + + +def get_text_sine_pos_embed( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, + exchange_xy: bool = True, +): + """generate sine position embedding from a position tensor + Args: + pos_tensor (torch.Tensor): shape: [..., n]. + num_pos_feats (int): projected shape for each float in the tensor. + temperature (int): temperature in the sine/cosine function. + exchange_xy (bool, optional): exchange pos x and pos y. For example, + input tensor is [x,y], the results will be [pos(y), pos(x)]. + Defaults to True. + Returns: + pos_embed (torch.Tensor): shape: [..., n*num_pos_feats]. + """ + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature**(2 * torch.div(dim_t, 2, rounding_mode='floor') / + num_pos_feats) + + def sine_func(x: torch.Tensor): + sin_x = x * scale / dim_t + sin_x = torch.stack((sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), + dim=3).flatten(2) + return sin_x + + pos_res = [ + sine_func(x) + for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1) + ] + if exchange_xy: + pos_res[0], pos_res[1] = pos_res[1], pos_res[0] + pos_res = torch.cat(pos_res, dim=-1) + return pos_res diff --git a/mmdet/models/necks/channel_mapper.py b/mmdet/models/necks/channel_mapper.py index 9700a2b3e72..74293618f2b 100644 --- a/mmdet/models/necks/channel_mapper.py +++ b/mmdet/models/necks/channel_mapper.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Tuple +from typing import List, Tuple, Union import torch.nn as nn from mmcv.cnn import ConvModule @@ -27,6 +27,9 @@ class ChannelMapper(BaseModule): normalization layer. Default: None. act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for activation layer in ConvModule. Default: dict(type='ReLU'). + bias (bool | str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise + False. Default: "auto". num_outs (int, optional): Number of output feature maps. There would be extra_convs when num_outs larger than the length of in_channels. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or dict], @@ -55,6 +58,7 @@ def __init__( conv_cfg: OptConfigType = None, norm_cfg: OptConfigType = None, act_cfg: OptConfigType = dict(type='ReLU'), + bias: Union[bool, str] = 'auto', num_outs: int = None, init_cfg: OptMultiConfig = dict( type='Xavier', layer='Conv2d', distribution='uniform') @@ -74,7 +78,8 @@ def __init__( padding=(kernel_size - 1) // 2, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + bias=bias)) if num_outs > len(in_channels): self.extra_convs = nn.ModuleList() for i in range(len(in_channels), num_outs): @@ -91,7 +96,8 @@ def __init__( padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, - act_cfg=act_cfg)) + act_cfg=act_cfg, + bias=bias)) def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]: """Forward function.""" diff --git a/mmdet/models/utils/vlfuse_helper.py b/mmdet/models/utils/vlfuse_helper.py index f6112bf5051..76b54de317c 100644 --- a/mmdet/models/utils/vlfuse_helper.py +++ b/mmdet/models/utils/vlfuse_helper.py @@ -94,7 +94,7 @@ def __init__(self, self.l_dim = l_dim assert ( - self.head_dim * self.num_heads == self.embed_dim + self.head_dim * self.num_heads == self.embed_dim ), 'embed_dim must be divisible by num_heads ' \ f'(got `embed_dim`: {self.embed_dim} ' \ f'and `num_heads`: {self.num_heads}).' @@ -134,10 +134,11 @@ def _reset_parameters(self): self.out_l_proj.bias.data.fill_(0) def forward( - self, - vision: Tensor, - lang: Tensor, - attention_mask_l: Optional[Tensor] = None + self, + vision: Tensor, + lang: Tensor, + attention_mask_v: Optional[Tensor] = None, + attention_mask_l: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: bsz, tgt_len, _ = vision.size() @@ -183,6 +184,13 @@ def forward( # Do not increase 50000, data type half has quite limited range attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE) + if attention_mask_v is not None: + attention_mask_v = ( + attention_mask_v[:, None, + None, :].repeat(1, self.num_heads, 1, + 1).flatten(0, 1)) + attn_weights_l.masked_fill_(attention_mask_v, float('-inf')) + attn_weights_l = attn_weights_l.softmax(dim=-1) if attention_mask_l is not None: @@ -324,10 +332,11 @@ def forward(self, return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature def single_attention_call( - self, - visual: Tensor, - lang: Tensor, - attention_mask_l: Optional[Tensor] = None + self, + visual: Tensor, + lang: Tensor, + attention_mask_v: Optional[Tensor] = None, + attention_mask_l: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Perform a single attention call between the visual and language inputs. @@ -335,6 +344,8 @@ def single_attention_call( Args: visual (Tensor): The visual input tensor. lang (Tensor): The language input tensor. + attention_mask_v (Optional[Tensor]): + An optional attention mask tensor for the visual input. attention_mask_l (Optional[Tensor]): An optional attention mask tensor for the language input. @@ -345,13 +356,50 @@ def single_attention_call( visual = self.layer_norm_v(visual) lang = self.layer_norm_l(lang) delta_v, delta_l = self.attn( - visual, lang, attention_mask_l=attention_mask_l) + visual, + lang, + attention_mask_v=attention_mask_v, + attention_mask_l=attention_mask_l) # visual, lang = visual + delta_v, l + delta_l visual = visual + self.drop_path(self.gamma_v * delta_v) lang = lang + self.drop_path(self.gamma_l * delta_l) return visual, lang +class SingleScaleBiAttentionBlock(BiAttentionBlock): + """This is a single-scale implementation of `BiAttentionBlock`. + + The only differenece between it and `BiAttentionBlock` is that the + `forward` function of `SingleScaleBiAttentionBlock` only accepts a single + flatten visual feature map, while the `forward` function in + `BiAttentionBlock` accepts multiple visual feature maps. + """ + + def forward(self, + visual_feature: Tensor, + lang_feature: Tensor, + attention_mask_v=None, + attention_mask_l=None): + """Single-scale forward pass. + + Args: + visual_feature (Tensor): The visual input tensor. Tensor of + shape (bs, patch_len, ch). + lang_feature (Tensor): The language input tensor. Tensor of + shape (bs, text_len, ch). + attention_mask_v (_type_, optional): Visual feature attention + mask. Defaults to None. + attention_mask_l (_type_, optional): Language feature attention + mask.Defaults to None. + """ + new_v, new_lang_feature = self.single_attention_call( + visual_feature, + lang_feature, + attention_mask_v=attention_mask_v, + attention_mask_l=attention_mask_l) + return new_v, new_lang_feature + + class VLFuse(nn.Module): """Early Fusion Module. diff --git a/model-index.yml b/model-index.yml index cbb379950e0..f1704c042cd 100644 --- a/model-index.yml +++ b/model-index.yml @@ -98,3 +98,4 @@ Import: - configs/masktrack_rcnn/metafile.yml - configs/glip/metafile.yml - configs/ddq/metafile.yml + - configs/grounding_dino/metafile.yml diff --git a/tools/model_converters/groundingdino_to_mmdet.py b/tools/model_converters/groundingdino_to_mmdet.py new file mode 100644 index 00000000000..b5896731d7b --- /dev/null +++ b/tools/model_converters/groundingdino_to_mmdet.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import subprocess +from collections import OrderedDict + +import torch +from mmengine.runner import CheckpointLoader + + +def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel) + return x + + +def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + +def convert(ckpt): + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + # + if 'module' not in k: + # NOTE: swin-b has no module prefix and swin-t has module prefix + k = 'module.' + k + if 'module.bbox_embed' in k: + # NOTE: bbox_embed name is swin-b is different from swin-t + k = k.replace('module.bbox_embed', + 'module.transformer.decoder.bbox_embed') + + if 'module.backbone.0' in k: + new_k = k.replace('module.backbone.0', 'backbone') + if 'patch_embed.proj' in new_k: + new_k = new_k.replace('patch_embed.proj', + 'patch_embed.projection') + elif 'pos_drop' in new_k: + new_k = new_k.replace('pos_drop', 'drop_after_pos') + + if 'layers' in new_k: + new_k = new_k.replace('layers', 'stages') + if 'mlp.fc1' in new_k: + new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in new_k: + new_k = new_k.replace('mlp.fc2', 'ffn.layers.1') + elif 'attn' in new_k: + new_k = new_k.replace('attn', 'attn.w_msa') + + if 'downsample' in k: + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + + elif 'module.bert' in k: + new_k = k.replace('module.bert', + 'language_model.language_backbone.body.model') + # new_k = k.replace('module.bert', 'bert') + + elif 'module.feat_map' in k: + new_k = k.replace('module.feat_map', 'text_feat_map') + + elif 'module.input_proj' in k: + new_k = k.replace('module.input_proj', 'neck.convs') + if 'neck.convs.3' in new_k: + # extra convs for 4th scale + new_k = new_k.replace('neck.convs.3', 'neck.extra_convs.0') + if '0.weight' in new_k: + # 0.weight -> conv.weight + new_k = new_k.replace('0.weight', 'conv.weight') + if '0.bias' in new_k: + # 0.bias -> conv.bias + new_k = new_k.replace('0.bias', 'conv.bias') + if '1.weight' in new_k: + # 1.weight -> gn.weight + new_k = new_k.replace('1.weight', 'gn.weight') + if '1.bias' in new_k: + # 1.bias -> gn.bias + new_k = new_k.replace('1.bias', 'gn.bias') + + elif 'module.transformer.level_embed' in k: + # module.transformer.level_embed -> level_embed + new_k = k.replace('module.transformer.level_embed', 'level_embed') + + elif 'module.transformer.encoder' in k: + # if '.layers' in k: + new_k = k.replace('module.transformer.encoder', 'encoder') + if 'norm1' in new_k: + new_k = new_k.replace('norm1', 'norms.0') + if 'norm2' in new_k: + new_k = new_k.replace('norm2', 'norms.1') + if 'norm3' in new_k: + new_k = new_k.replace('norm3', 'norms.2') + if 'linear1' in new_k: + new_k = new_k.replace('linear1', 'ffn.layers.0.0') + if 'linear2' in new_k: + new_k = new_k.replace('linear2', 'ffn.layers.1') + + if 'text_layers' in new_k and 'self_attn' in new_k: + new_k = new_k.replace('self_attn', 'self_attn.attn') + + elif 'module.transformer.enc_output' in k: + if 'module.transformer.enc_output' in k and 'norm' not in k: + new_k = k.replace('module.transformer.enc_output', + 'memory_trans_fc') + if 'module.transformer.enc_output_norm' in k: + new_k = k.replace('module.transformer.enc_output_norm', + 'memory_trans_norm') + + elif 'module.transformer.enc_out_bbox_embed.layers' in k: + # ugly version + if 'module.transformer.enc_out_bbox_embed.layers.0' in k: + new_k = k.replace( + 'module.transformer.enc_out_bbox_embed.layers.0', + 'bbox_head.reg_branches.6.0') + if 'module.transformer.enc_out_bbox_embed.layers.1' in k: + new_k = k.replace( + 'module.transformer.enc_out_bbox_embed.layers.1', + 'bbox_head.reg_branches.6.2') + if 'module.transformer.enc_out_bbox_embed.layers.2' in k: + new_k = k.replace( + 'module.transformer.enc_out_bbox_embed.layers.2', + 'bbox_head.reg_branches.6.4') + + elif 'module.transformer.tgt_embed' in k: + new_k = k.replace('module.transformer.tgt_embed', + 'query_embedding') + + elif 'module.transformer.decoder' in k: + new_k = k.replace('module.transformer.decoder', 'decoder') + if 'norm1' in new_k: + # norm1 in official GroundingDINO is the third norm in decoder + new_k = new_k.replace('norm1', 'norms.2') + if 'catext_norm' in new_k: + # catext_norm in official GroundingDINO is the + # second norm in decoder + new_k = new_k.replace('catext_norm', 'norms.1') + if 'norm2' in new_k: + # norm2 in official GroundingDINO is the first norm in decoder + new_k = new_k.replace('norm2', 'norms.0') + if 'norm3' in new_k: + new_k = new_k.replace('norm3', 'norms.3') + if 'ca_text' in new_k: + new_k = new_k.replace('ca_text', 'cross_attn_text') + if 'in_proj_weight' in new_k: + new_k = new_k.replace('in_proj_weight', + 'attn.in_proj_weight') + if 'in_proj_bias' in new_k: + new_k = new_k.replace('in_proj_bias', 'attn.in_proj_bias') + if 'out_proj.weight' in new_k: + new_k = new_k.replace('out_proj.weight', + 'attn.out_proj.weight') + if 'out_proj.bias' in new_k: + new_k = new_k.replace('out_proj.bias', + 'attn.out_proj.bias') + if 'linear1' in new_k: + new_k = new_k.replace('linear1', 'ffn.layers.0.0') + if 'linear2' in new_k: + new_k = new_k.replace('linear2', 'ffn.layers.1') + if 'self_attn' in new_k: + new_k = new_k.replace('self_attn', 'self_attn.attn') + if 'bbox_embed' in new_k: + reg_layer_id = int(new_k.split('.')[2]) + linear_id = int(new_k.split('.')[4]) + weight_or_bias = new_k.split('.')[-1] + new_k = 'bbox_head.reg_branches.' + \ + str(reg_layer_id)+'.'+str(2*linear_id)+'.'+weight_or_bias + + else: + print('skip:', k) + continue + + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys to mmdet style.') + parser.add_argument( + 'src', + default='groundingdino_swint_ogc.pth.pth', + help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument( + 'dst', + default='groundingdino_swint_ogc.pth_mmdet.pth', + help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + weight = convert(state_dict) + torch.save(weight, args.dst) + sha = subprocess.check_output(['sha256sum', args.dst]).decode() + final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8]) + subprocess.Popen(['mv', args.dst, final_file]) + print(f'Done!!, save to {final_file}') + + +if __name__ == '__main__': + main()