Get it running end-to-end + support speed testing
lizexu123 committed Dec 29, 2023
1 parent 521157e commit 6a2ec0b
94 changes: 84 additions & 10 deletions example/auto_compression/detection/paddle_inference_eval.py
@@ -82,9 +82,15 @@ def argsparser():
parser.add_argument("--img_shape", type=int, default=640, help="input_size")
parser.add_argument(
'--include_nms',
-type=bool,
-default=True,
+type=str,
+default='True',
help="Whether include nms or not.")
+# Whether to run in speed-test mode
+parser.add_argument(
+'--speed',
+type=str,
+default='True',
+help="If 'True', only measure and report inference time; skip mAP evaluation.")

return parser

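A note on the type=bool to type=str change above: argparse does not parse boolean strings, and bool('False') is True, so type=bool silently turns every non-empty value into True. The commit works around this by parsing strings and converting in __main__. A common alternative, sketched here with a hypothetical str2bool helper that is not part of this file, converts at parse time instead:

    import argparse

    def str2bool(v):
        # argparse's type=bool treats any non-empty string (even 'False')
        # as True, so parse the usual spellings explicitly.
        if isinstance(v, bool):
            return v
        if v.lower() in ('true', 't', 'yes', '1'):
            return True
        if v.lower() in ('false', 'f', 'no', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

    parser = argparse.ArgumentParser()
    parser.add_argument('--include_nms', type=str2bool, default=True)
    parser.add_argument('--speed', type=str2bool, default=True)
    print(parser.parse_args(['--speed', 'False']))  # speed parsed as real False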
@@ -238,9 +244,11 @@ def load_predictor(
config = Config(
os.path.join(model_dir, "model.pdmodel"),
os.path.join(model_dir, "model.pdiparams"))

+config.enable_memory_optim()
if device == "GPU":
# initial GPU memory(M), device ID
-config.enable_use_gpu(200, 0)
+config.enable_use_gpu(1000, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
else:
@@ -260,7 +268,7 @@
}
if precision in precision_map.keys() and use_trt:
config.enable_tensorrt_engine(
-workspace_size=(1 << 25) * batch_size,
+workspace_size=(1 << 30) * batch_size,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[precision],
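For scale: the TensorRT workspace grows from (1 << 25) = 32 MB to (1 << 30) = 1 GiB per batch, and the initial GPU memory pool passed to enable_use_gpu grows from 200 MB to 1000 MB. A minimal, self-contained sketch of the resulting predictor setup follows; the model paths, max_batch_size, and min_subgraph_size are placeholders, not values taken from this file:

    from paddle.inference import Config, PrecisionType, create_predictor

    config = Config('model.pdmodel', 'model.pdiparams')  # placeholder paths
    config.enable_memory_optim()
    config.enable_use_gpu(1000, 0)      # 1000 MB initial pool on GPU 0
    config.switch_ir_optim(True)        # graph optimization / op fusion
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,         # 1 GiB scratch space for TensorRT
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False)
    predictor = create_predictor(config)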
@@ -297,6 +305,7 @@ def predict_image(predictor,
img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {}
inputs["image"] = img

if FLAGS.include_nms:
inputs['scale_factor'] = scale_factor
input_names = predictor.get_input_names()
@@ -354,6 +363,9 @@ def eval(predictor, val_loader, metric, rerun_flag=False):
input_names = predictor.get_input_names()
output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0])
print("output_names:", output_names)
print("Number of outputs:", len(output_names))
print("FLAGS.include_nms:", FLAGS.include_nms)
if FLAGS.include_nms:
boxes_num = predictor.get_output_handle(output_names[1])
for batch_id, data in enumerate(val_loader):
@@ -374,27 +386,79 @@
time_min = min(time_min, timed)
time_max = max(time_max, timed)
predict_time += timed
-if not FLAGS.include_nms:
+# If include_nms is False and FLAGS.speed is True, decode with PPYOLOEPostProcess
+if not FLAGS.include_nms and FLAGS.speed:
postprocess = PPYOLOEPostProcess(
score_threshold=0.3, nms_threshold=0.6)
res = postprocess(np_boxes, data_all['scale_factor'])
-else:
+# If include_nms is False and FLAGS.speed is False, skip this sample
+elif not FLAGS.include_nms and not FLAGS.speed:
+continue
+# If include_nms is True, the model already ran NMS; use its outputs directly
+elif FLAGS.include_nms:
res = {'bbox': np_boxes, 'bbox_num': np_boxes_num}
metric.update(data_all, res)
if batch_id % 100 == 0:
print("Eval iter:", batch_id)
sys.stdout.flush()
metric.accumulate()
-metric.log()
+if not FLAGS.speed:
+metric.log()
map_res = metric.get_results()
metric.reset()
time_avg = predict_time / sample_nums
print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
print("[Benchmark] COCO mAP: {}".format(map_res["bbox"][0]))
if not FLAGS.speed:
print("[Benchmark] COCO mAP: {}".format(map_res["bbox"][0]))
sys.stdout.flush()
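The (include_nms, speed) branches above are worth spelling out, since main() only calls eval() when FLAGS.speed is False; on that path the speed-True branch is dead code and the new `if not FLAGS.speed` guards are always true. A tiny standalone sketch of the decision table, mirroring (not taken verbatim from) the logic in this commit:

    def choose_postprocess(include_nms, speed):
        # Mirrors the branch logic added to eval() in this commit (sketch only).
        if not include_nms and speed:
            return 'decode on host with PPYOLOEPostProcess'
        if not include_nms and not speed:
            return 'skip sample: no result, no metric update'
        return 'model ran NMS; use predictor outputs as-is'

    for nms in (False, True):
        for speed in (False, True):
            print(nms, speed, '->', choose_postprocess(nms, speed))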

+def inference_time(predictor, val_loader, metric, rerun_flag=False):
+cpu_mems, gpu_mems = 0, 0
+predict_time = 0.0
+time_min = float("inf")
+time_max = float("-inf")
+sample_nums = len(val_loader)
+input_names = predictor.get_input_names()
+output_names = predictor.get_output_names()
+boxes_tensor = predictor.get_output_handle(output_names[0])
+print("output_names:", output_names)
+print("Number of outputs:", len(output_names))
+print("FLAGS.include_nms:", FLAGS.include_nms)
+if FLAGS.include_nms:
+boxes_num = predictor.get_output_handle(output_names[1])
+
+for batch_id, data in enumerate(val_loader):
+data_all = {k: np.array(v) for k, v in data.items()}
+for i, _ in enumerate(input_names):
+input_tensor = predictor.get_input_handle(input_names[i])
+input_tensor.copy_from_cpu(data_all[input_names[i]])
+paddle.device.cuda.synchronize()
+start_time = time.time()
+predictor.run()
+# np_boxes = boxes_tensor.copy_to_cpu()
+if FLAGS.include_nms:
+np_boxes_num = boxes_num.copy_to_cpu()
+if rerun_flag:
+return
+end_time = time.time()
+timed = end_time - start_time
+time_min = min(time_min, timed)
+time_max = max(time_max, timed)
+predict_time += timed
+time_avg = predict_time / sample_nums
+print("[Benchmark]Inference time(ms): min={}, max={}, avg={}".format(
+round(time_min * 1000, 2),
+round(time_max * 1000, 1), round(time_avg * 1000, 1)))
+sys.stdout.flush()
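One timing caveat in inference_time(): the loop synchronizes the GPU before starting the clock, but when FLAGS.include_nms is False nothing forces the GPU to finish before end_time is taken (no copy_to_cpu and no second synchronize), so the measured span can undercount kernel time. A minimal sketch of a tighter measurement, assuming a constructed predictor whose inputs were already copied to its input handles:

    import time
    import paddle

    def timed_run(predictor, n_warmup=10, n_iters=100):
        # Warm up so one-time allocations and TensorRT engine builds
        # don't pollute the measurement.
        for _ in range(n_warmup):
            predictor.run()
        paddle.device.cuda.synchronize()   # fence before starting the clock
        start = time.time()
        for _ in range(n_iters):
            predictor.run()
        paddle.device.cuda.synchronize()   # fence again so kernels have finished
        return (time.time() - start) / n_iters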

def main():
"""
@@ -421,7 +485,7 @@ def main():
repeats=repeats)
else:
reader_cfg = load_config(FLAGS.reader_config)

dataset = reader_cfg["EvalDataset"]
global val_loader
val_loader = create("EvalReader")(
@@ -432,7 +496,10 @@
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, IouType="bbox")
-eval(predictor, val_loader, metric, rerun_flag=rerun_flag)
+if not FLAGS.speed:
+eval(predictor, val_loader, metric, rerun_flag=rerun_flag)
+else:
+inference_time(predictor, val_loader, metric, rerun_flag=rerun_flag)

if rerun_flag:
print(
@@ -444,6 +511,13 @@
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
+if FLAGS.include_nms == 'True':
+FLAGS.include_nms = True
+else:
+FLAGS.include_nms = False
+# Convert --speed from str to bool as well; without this, any non-empty
+# string (including 'False') is truthy and speed mode can never be disabled.
+FLAGS.speed = FLAGS.speed == 'True'
+
+print('**************main****************')
+print(FLAGS)

# DataLoader need run on cpu
paddle.set_device("cpu")
