Skip to content

Commit

Permalink
update model save.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jan 15, 2024
1 parent 6d6e7e2 commit 79895e8
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions examples/macbert/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ def main():
callbacks=[ckpt_callback]
)
# 进行训练
# train_loader中有数据
torch.autograd.set_detect_anomaly(True)
if 'train' in cfg.MODE and train_loader and len(train_loader) > 0:
if valid_loader and len(valid_loader) > 0:
trainer.fit(model, train_loader, valid_loader)
else:
trainer.fit(model, train_loader)
logger.info('train model done.')
# 进行测试的逻辑同训练
if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
trainer.test(model, test_loader)
# 模型转为transformers可加载
if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
ckpt_path = ckpt_callback.best_model_path
Expand All @@ -111,17 +113,23 @@ def main():
if ckpt_path and os.path.exists(ckpt_path):
tokenizer.save_pretrained(cfg.OUTPUT_DIR)
if cfg.MODEL.NAME == 'softmaskedbert4csc':
m = SoftMaskedBert4Csc.load_from_checkpoint(ckpt_path)
model = SoftMaskedBert4Csc.load_from_checkpoint(
checkpoint_path=ckpt_path,
cfg=cfg,
map_location=device,
tokenizer=tokenizer
)
else:
m = MacBert4Csc.load_from_checkpoint(ckpt_path)
model = MacBert4Csc.load_from_checkpoint(
checkpoint_path=ckpt_path,
cfg=cfg,
map_location=device,
tokenizer=tokenizer
)
model.eval()
# 保存finetune训练后的模型文件pytorch_model.bin
pt_file = os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin')
m.bert.save_pretrained(pt_file)
del m
# 进行测试的逻辑同训练
if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
trainer.test(model, test_loader)

model.bert.save_pretrained(pt_file)

if __name__ == '__main__':
main()

0 comments on commit 79895e8

Please sign in to comment.