Added method to replace BN layers
andreped committed Sep 8, 2023
1 parent c90109b commit 7769b41
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions gradient_accumulator/utils.py
@@ -0,0 +1,68 @@
import tensorflow as tf
from .layers import AccumBatchNormalization


def replace_batchnorm_layers(model, accum_steps, position='after'):
    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in network_dict['input_layers_of']:
                network_dict['input_layers_of'].update(
                    {layer_name: [layer.name]})
            else:
                network_dict['input_layers_of'][layer_name].append(layer.name)

    # Set the output tensor of the input layer
    network_dict['new_output_tensor_of'].update({model.layers[0].name: model.input})

    # Iterate over all layers after the input
    model_outputs = []
    iter_ = 0
    for layer in model.layers[1:]:
        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                       for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Swap in AccumBatchNormalization if this is a BatchNormalization layer
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            if position == 'replace':
                x = layer_input
            else:
                raise ValueError('position must be: replace')

            # Get weights of current layer
            old_weights = layer.get_weights()

            # build new layer
            new_layer = AccumBatchNormalization(accum_steps=accum_steps, name="AccumBatchNormalization_" + str(iter_))
            new_layer.build(input_shape=layer.input_shape)

            iter_ += 1

            # set weights in new layer to match old layer
            new_layer.accum_mean = layer.moving_mean
            new_layer.moving_mean = layer.moving_mean

            new_layer.accum_variance = layer.moving_variance
            new_layer.moving_variance = layer.moving_variance

            # forward step
            x = new_layer(x)

        else:
            x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is an output of the initial model
        if layer_name in model.output_names:
            model_outputs.append(x)

    return tf.keras.Model(inputs=model.inputs, outputs=x)
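
Example usage (a minimal sketch; the toy model and accum_steps value below are illustrative, assuming the helper is imported from gradient_accumulator.utils):

import tensorflow as tf

from gradient_accumulator.utils import replace_batchnorm_layers

# Small functional model containing a BatchNormalization layer to swap out.
inputs = tf.keras.Input(shape=(16,))
x = tf.keras.layers.Dense(32)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation("relu")(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Rebuild the graph with AccumBatchNormalization in place of BatchNormalization.
# Only position='replace' is accepted for BatchNormalization layers, so pass it explicitly.
model = replace_batchnorm_layers(model, accum_steps=4, position='replace')
model.compile(optimizer="adam", loss="mse")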
