Skip to content

Commit

Permalink
finetune MM-GDINO on ov_coco and ov_lvis (#11304)
Browse files Browse the repository at this point in the history
  • Loading branch information
xushilin1 committed Dec 22, 2023
1 parent dfffb99 commit 63713c9
Show file tree
Hide file tree
Showing 4 changed files with 547 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
_base_ = '../grounding_dino_swin-t_pretrain_obj365.py'

data_root = 'data/coco/'
base_classes = ('person', 'bicycle', 'car', 'motorcycle', 'train', 'truck',
'boat', 'bench', 'bird', 'horse', 'sheep', 'bear', 'zebra',
'giraffe', 'backpack', 'handbag', 'suitcase', 'frisbee',
'skis', 'kite', 'surfboard', 'bottle', 'fork', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'pizza', 'donut', 'chair', 'bed', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'microwave', 'oven', 'toaster',
'refrigerator', 'book', 'clock', 'vase', 'toothbrush')
novel_classes = ('airplane', 'bus', 'cat', 'dog', 'cow', 'elephant',
'umbrella', 'tie', 'snowboard', 'skateboard', 'cup', 'knife',
'cake', 'couch', 'keyboard', 'sink', 'scissors')
all_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'kite', 'skateboard', 'surfboard',
'bottle', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'pizza',
'donut', 'cake', 'chair', 'couch', 'bed', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'toothbrush')

train_metainfo = dict(classes=base_classes)
test_metainfo = dict(
classes=all_classes,
base_classes=base_classes,
novel_classes=novel_classes)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction', 'text',
'custom_entities'))
]

test_pipeline = [
dict(
type='LoadImageFromFile', backend_args=None,
imdecode_backend='pillow'),
dict(
type='FixScaleResize',
scale=(800, 1333),
keep_ratio=True,
backend='pillow'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'custom_entities',
'tokens_positive'))
]

train_dataloader = dict(
dataset=dict(
_delete_=True,
type='CocoDataset',
metainfo=train_metainfo,
data_root=data_root,
ann_file='zero-shot/instances_train2017_seen_2.json',
data_prefix=dict(img='train2017/'),
return_classes=True,
filter_cfg=dict(filter_empty_gt=False, min_size=32),
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='CocoDataset',
metainfo=test_metainfo,
data_root=data_root,
ann_file='zero-shot/instances_val2017_all_2.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=test_pipeline,
return_classes=True,
))
test_dataloader = val_dataloader

val_evaluator = dict(
type='OVCocoMetric',
ann_file=data_root + 'zero-shot/instances_val2017_all_2.json',
metric='bbox',
format_only=False)
test_evaluator = val_evaluator

optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.00005, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'backbone': dict(lr_mult=0.1),
# 'language_model': dict(lr_mult=0),
}))

# learning policy
max_epochs = 12
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
train_cfg = dict(max_epochs=max_epochs, val_interval=1)

default_hooks = dict(
checkpoint=dict(
max_keep_ckpts=1, save_best='coco/novel_ap50', rule='greater'))

load_from = 'epoch_30.pth'
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
_base_ = '../grounding_dino_swin-t_pretrain_obj365.py'

data_root = 'data/lvis/'

model = dict(test_cfg=dict(
max_per_img=300,
chunked_size=40,
))

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[
[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
# The radio of all image in train dataset < 7
# follow the original implement
scales=[(400, 4200), (500, 4200), (600, 4200)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
]
]),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
dict(
type='RandomSamplingNegPos',
tokenizer_name=_base_.lang_model_name,
num_sample_negative=85,
# change this
label_map_file='data/lvis/annotations/lvis_v1_label_map_norare.json',
max_tokens=256),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction', 'text',
'custom_entities', 'tokens_positive', 'dataset_mode'))
]

train_dataloader = dict(
dataset=dict(
_delete_=True,
type='ClassBalancedDataset',
oversample_thr=1e-3,
dataset=dict(
type='ODVGDataset',
data_root=data_root,
need_text=False,
label_map_file='annotations/lvis_v1_label_map_norare.json',
ann_file='annotations/lvis_v1_train_od_norare.json',
data_prefix=dict(img=''),
filter_cfg=dict(filter_empty_gt=False, min_size=32),
return_classes=True,
pipeline=train_pipeline)))

val_dataloader = dict(
dataset=dict(
data_root=data_root,
type='LVISV1Dataset',
ann_file='annotations/lvis_v1_minival_inserted_image_name.json',
data_prefix=dict(img='')))
test_dataloader = val_dataloader

val_evaluator = dict(
_delete_=True,
type='LVISFixedAPMetric',
ann_file=data_root +
'annotations/lvis_v1_minival_inserted_image_name.json')
test_evaluator = val_evaluator

optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.00005, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'backbone': dict(lr_mult=0.1),
# 'language_model': dict(lr_mult=0),
}))

# learning policy
max_epochs = 12
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
train_cfg = dict(max_epochs=max_epochs, val_interval=3)

default_hooks = dict(
checkpoint=dict(
max_keep_ckpts=3, save_best='lvis_fixed_ap/AP', rule='greater'))

load_from = 'epoch_30.pth'
4 changes: 3 additions & 1 deletion mmdet/evaluation/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .lvis_metric import LVISMetric
from .mot_challenge_metric import MOTChallengeMetric
from .openimages_metric import OpenImagesMetric
from .ov_coco_metric import OVCocoMetric
from .refexp_metric import RefExpMetric
from .refseg_metric import RefSegMetric
from .reid_metric import ReIDMetrics
Expand All @@ -29,5 +30,6 @@
'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric',
'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric',
'COCOCaptionMetric', 'SemSegMetric', 'RefSegMetric', 'RefExpMetric',
'gRefCOCOMetric', 'DODCocoMetric', 'DumpODVGResults', 'Flickr30kMetric'
'gRefCOCOMetric', 'DODCocoMetric', 'DumpODVGResults', 'Flickr30kMetric',
'OVCocoMetric'
]
Loading

0 comments on commit 63713c9

Please sign in to comment.