-
Notifications
You must be signed in to change notification settings - Fork 43
/
main.py
78 lines (60 loc) · 2.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import json
import tensorflow as tf
import numpy as np
# from model.vae import VAWGAN
from analyzer import read, Tanhize
from util.wrapper import save, validate_log_dirs #, load, configure_gpu_settings, restore_global_step
# from trainer.vae import VAWGANTrainer
from importlib import import_module
# Command-line flags; the rest of the script reads them through `args`.
args = tf.app.flags.FLAGS

# --- Logging / checkpoint options -------------------------------------------
tf.app.flags.DEFINE_string('logdir_root', None, 'root of log dir')
tf.app.flags.DEFINE_string('logdir', None, 'log dir')
tf.app.flags.DEFINE_string(
    'restore_from',
    None,
    'restore from dir (not from *.ckpt)',
)
tf.app.flags.DEFINE_string('gpu_cfg', None, 'GPU configuration')
tf.app.flags.DEFINE_integer('summary_freq', 1000, 'Update summary')
tf.app.flags.DEFINE_string(
    'ckpt',
    None,
    'specify the ckpt in restore_from (if there are multiple ckpts)',
)  # TODO

# --- Model / trainer selection ----------------------------------------------
tf.app.flags.DEFINE_string(
    'architecture',
    'architecture-vawgan-vcc2016.json',
    'network architecture',
)
tf.app.flags.DEFINE_string('model_module', 'model.vae', 'Model module')
tf.app.flags.DEFINE_string('model', None, 'Model: ConvVAE, VAWGAN')
tf.app.flags.DEFINE_string('trainer_module', 'trainer.vae', 'Trainer module')
tf.app.flags.DEFINE_string(
    'trainer', None, 'Trainer: VAETrainer, VAWGANTrainer')

# Neither `model` nor `trainer` has a default, so refuse to continue unless
# both were supplied.
# NOTE(review): this check runs while the module is being loaded, i.e. it
# relies on flag values being readable before any explicit parse call --
# confirm this holds for the TensorFlow version in use.
if None in (args.model, args.trainer):
    raise ValueError(
        '\n Both `model` and `trainer` should be assigned.'
        '\n Use `python main.py --help` to see applicable options.'
    )

# Resolve the model and trainer classes by name from the modules given on the
# command line.
MODEL = getattr(import_module(args.model_module), args.model)
module = import_module(args.trainer_module)
TRAINER = getattr(module, args.trainer)
def main():
    ''' Set up logging, the input batch and the model, then run training.

    Steps, in order:
      1. Resolve the log directories with `validate_log_dirs` and create
         `dirs['logdir']`.
      2. Load the JSON architecture named by `--architecture` and store a
         re-indented copy of it inside the log directory.
      3. Build a `Tanhize` normalizer from `./etc/xmax.npf` and
         `./etc/xmin.npf`, and obtain the `image`/`label` batch from `read`.
      4. Instantiate the selected model, compute its loss, and pass it to the
         selected trainer for `arch['training']['max_iter']` iterations.

    NOTE: The input is rescaled to [-1, 1] (original author's note; the
    rescaling itself is delegated to the normalizer handed to `read`).
    '''
    dirs = validate_log_dirs(args)
    tf.gfile.MakeDirs(dirs['logdir'])

    with open(args.architecture) as f:
        arch = json.load(f)

    # Keep a copy of the architecture next to the logs. Only the file name of
    # the flag value is used: joining the raw value would point into a
    # sub-directory that does not exist under `logdir` when the flag contains
    # directories, and `os.path.join` drops `logdir` altogether when the flag
    # is an absolute path (which would open the source file for writing).
    arch_copy = os.path.join(
        dirs['logdir'], os.path.basename(args.architecture))
    with open(arch_copy, 'w') as f:
        json.dump(arch, f, indent=4)

    # Per-dimension extrema loaded from disk.
    # NOTE(review): `np.fromfile` is called with its default dtype -- confirm
    # this matches how the .npf files were written.
    normalizer = Tanhize(
        xmax=np.fromfile('./etc/xmax.npf'),
        xmin=np.fromfile('./etc/xmin.npf'),
    )

    image, label = read(
        file_pattern=arch['training']['datadir'],
        batch_size=arch['training']['batch_size'],
        capacity=2048,
        min_after_dequeue=1024,
        normalizer=normalizer,
    )

    machine = MODEL(arch)
    loss = machine.loss(image, label)
    trainer = TRAINER(loss, arch, args, dirs)
    trainer.train(nIter=arch['training']['max_iter'], machine=machine)


if __name__ == '__main__':
    main()