"Make Keras a Bit Cooler!": Reuse Techniques for Layers and Models

By 苏剑林 | September 29, 2019

Today we continue to dig deep into Keras and once again experience its uniquely elegant design. This time our focus is on "reuse": primarily the repeated use of layers and models.

Generally speaking, reuse serves two goals. The first is to share weights: two layers not only perform the same operation but also use the same weights, which are updated together. The second is to avoid rewriting code: for example, when we have already built a model and want to decompose it to construct sub-models.

Basics

In fact, Keras has already anticipated much of this for us, so mastering the basic usage is enough to cover most needs.

Layer Reuse

Layer reuse is the simplest: initialize a layer, store it, and then call it repeatedly:

from keras.layers import Input, Dense

x_in = Input(shape=(784,))
x = x_in

layer = Dense(784, activation='relu') # Initialize a layer and store it

x = layer(x) # First call
x = layer(x) # Subsequent call
x = layer(x) # Subsequent call

It is important to note that you must first initialize a layer and store it as a variable before calling it to ensure that the repeated calls share weights. Conversely, if the code follows the form below, the weights are not shared:

x = Dense(784, activation='relu')(x)
x = Dense(784, activation='relu')(x) # Does not share weights with the previous one
x = Dense(784, activation='relu')(x) # Does not share weights with the previous ones
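
To see the difference concretely, here is a small self-contained check (a sketch that rebuilds both patterns): counting trainable weights shows that the shared version owns a single kernel/bias pair, while the unshared version owns three.

from keras.layers import Input, Dense
from keras.models import Model

x_in = Input(shape=(784,))

# Shared pattern: one Dense layer called three times -> one kernel/bias pair
layer = Dense(784, activation='relu')
shared = Model(x_in, layer(layer(layer(x_in))))
print(len(shared.trainable_weights)) # 2

# Unshared pattern: three separate Dense layers -> three kernel/bias pairs
x = Dense(784, activation='relu')(x_in)
x = Dense(784, activation='relu')(x)
x = Dense(784, activation='relu')(x)
unshared = Model(x_in, x)
print(len(unshared.trainable_weights)) # 6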

Model Reuse

Keras models behave similarly to layers; they can be called in the same way as layers. For example:

from keras.models import Model

x_in = Input(shape=(784,))
x = x_in

x = Dense(10, activation='softmax')(x)

model = Model(x_in, x) # Build model

x_in = Input(shape=(100,))
x = x_in

x = Dense(784, activation='relu')(x)
x = model(x) # Use the model like a layer

model2 = Model(x_in, x)

Friends who have read the Keras source code will understand that the reason a model can be used like a layer is that Model itself inherits from the Layer class, so models naturally inherit certain characteristics from layers.
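
If you want to verify this without digging through the source, a one-line check is enough (a quick sketch):

from keras.layers import Layer
from keras.models import Model

# Model is (indirectly) a subclass of Layer, which is why model(x) works just like layer(x)
print(issubclass(Model, Layer)) # True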

Model Cloning

Model cloning is similar to model reuse, except that the resulting new model does not share weights with the original. In other words, only the model architecture is copied, and the two models are updated independently of each other. Keras provides a dedicated function for cloning models, which can be called directly:

from keras.models import clone_model

model2 = clone_model(model1)

Note that clone_model completely copies the architecture of the original model and reconstructs a new model, but it does not copy the values of the original model's weights. That is to say, for the same input, the results of model1.predict and model2.predict will be different.

If you need to transfer the weights as well, you have to copy them over manually with set_weights:

from keras import backend as K

model2.set_weights(K.batch_get_value(model1.weights))
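
Putting the two steps together, here is a minimal sketch that clones a toy model, copies the weights over, and checks that the two models now agree (model1.get_weights() would serve equally well as the source of the weights):

import numpy as np
from keras import backend as K
from keras.layers import Input, Dense
from keras.models import Model, clone_model

# Build a toy model, clone it, then copy the weights across
x_in = Input(shape=(8,))
model1 = Model(x_in, Dense(4, activation='relu')(x_in))
model2 = clone_model(model1)
model2.set_weights(K.batch_get_value(model1.weights))

# After copying, the two models give identical predictions
x = np.random.rand(3, 8)
print(np.allclose(model1.predict(x), model2.predict(x))) # True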

Advanced

The above discussions involved calling existing layers or models exactly as they are, which is relatively straightforward as Keras has already prepared for it. Below are some more complex examples.

Cross-referencing

Cross-referencing here refers to using the weights of an existing layer when defining a new layer. Note that this custom layer might have a completely different function from the original layer; they simply share a certain weight. For example, in BERT, when training the MLM (Masked Language Model), the final fully connected layer that predicts word probabilities shares its weights with the Embedding layer.

A reference implementation is as follows:

from keras import backend as K
from keras.layers import Layer, Activation, Embedding, Input

class EmbeddingDense(Layer):
    """Operates like Dense, but its kernel is the (transposed) embedding matrix of the given
    Embedding layer, which must already be built (i.e. called at least once) so that its
    embeddings weight exists.
    """
    def __init__(self, embedding_layer, activation='softmax', **kwargs):
        super(EmbeddingDense, self).__init__(**kwargs)
        self.kernel = K.transpose(embedding_layer.embeddings)
        self.activation = activation
        self.units = K.int_shape(self.kernel)[1]

    def build(self, input_shape):
        super(EmbeddingDense, self).build(input_shape)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.units,),
                                    initializer='zeros')

    def call(self, inputs):
        outputs = K.dot(inputs, self.kernel)
        outputs = K.bias_add(outputs, self.bias)
        outputs = Activation(self.activation).call(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self.units,)

# Usage (assuming the inputs are integer token ids)
x_in = Input(shape=(None,), dtype='int32')
embedding_layer = Embedding(10000, 128) # Initialize the Embedding layer
x = embedding_layer(x_in) # Call the Embedding layer (this builds it)
x = EmbeddingDense(embedding_layer)(x) # Call the EmbeddingDense layer, which reuses the embedding matrix as its kernel
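
As a quick sanity check (a sketch building on the usage above), the only trainable weights in the resulting graph are the shared embedding matrix and EmbeddingDense's bias; there is no separate kernel, and gradients with respect to the kernel flow back into embedding_layer.embeddings:

model = Model(x_in, x)
# Expect something like [(10000, 128), (10000,)]: the embedding matrix plus the bias
print([K.int_shape(w) for w in model.trainable_weights])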

Extracting Intermediate Layers

Sometimes we need to extract features from intermediate layers of a pre-built model and construct a new model. In Keras, this is also a very simple operation:

from keras.applications.resnet50 import ResNet50
model = ResNet50(weights='imagenet')

model2 = Model(
    inputs=model.input,
    outputs=[
        model.get_layer('res5a_branch1').output,
        model.get_layer('activation_47').output,
    ]
)
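
A quick usage sketch: since the new model has two outputs, predict returns a list with one array per requested output (224x224x3 is ResNet50's default input size):

import numpy as np

features = model2.predict(np.random.rand(1, 224, 224, 3))
print([f.shape for f in features]) # one feature map per requested intermediate layer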

Splitting from the Middle

Finally, we come to the most challenging part of this article: splitting a model from the middle. Once you understand this, you can also implement operations such as inserting or replacing layers in an existing model. This requirement might seem unusual, but it has in fact been asked about on StackOverflow, so it clearly has its uses.

Suppose we have an existing model that can be decomposed as $$ \text{inputs} \to h_1 \to h_2 \to h_3 \to h_4 \to \text{outputs} $$ We might need to replace $h_2$ with a new input and then connect the subsequent layers to build a new model, i.e., the new model would compute $$ \text{new input} \to h_3 \to h_4 \to \text{outputs} $$ If the original were a Sequential model, this would be quite simple: just iterate over model.layers to build the new model:

x_in = Input(shape=(100,)) # new input playing the role of h2
x = x_in

for layer in model.layers[2:]: # reuse the layers that produce h3, h4 and the outputs
    x = layer(x)

model2 = Model(x_in, x)
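
For concreteness, here is a self-contained toy version of the above (a sketch with made-up layer sizes): the original Sequential model is decomposed into four Dense layers, and the new model starts from the third one, sharing its weights with the original.

from keras.models import Sequential, Model
from keras.layers import Dense, Input

# Original model: inputs -> h1 -> h2 -> h3 -> h4 (outputs)
model = Sequential([
    Dense(256, activation='relu', input_shape=(784,)), # produces h1
    Dense(100, activation='relu'),                     # produces h2
    Dense(64, activation='relu'),                      # produces h3
    Dense(10, activation='softmax'),                   # produces h4 = outputs
])

# New model: a 100-dim input takes the place of h2, then the remaining layers are reused
x_in = Input(shape=(100,))
x = x_in
for layer in model.layers[2:]:
    x = layer(x)
model2 = Model(x_in, x)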

However, if the model has a more complex structure, such as a residual structure that doesn't follow a single linear path, it’s not that simple. In reality, this task isn't inherently difficult; the required code is already part of Keras, but a ready-made interface isn't provided. Why do I say this? Because when we call an existing model using code like model(x), Keras essentially reconstructs the entire existing model from start to finish. Since it can rebuild the entire model, building "half" a model is technically feasible; there just isn't a public API for it. For technical details, refer to the run_internal_graph function in keras/engine/network.py within the Keras source code.

The logic for fully reconstructing a model resides in the run_internal_graph function. It’s not a simple function, so it’s best not to rewrite it. If we want to use that logic to split a model, the only way is to "graft" a solution: modify certain attributes of the existing model to trick the run_internal_graph function into thinking the model's input layer is an intermediate layer rather than the original input layer. With this idea in mind and a careful reading of the run_internal_graph code, one can derive the following reference code:

from keras import backend as K
from keras.layers import Input
from keras.models import Model


def get_outputs_of(model, start_tensors, input_layers=None):
    """start_tensors: the tensor(s) at which to split the model
    """
    # Wrap the original model in a new Model object, so the attribute
    # surgery below leaves the original model untouched
    model = Model(inputs=model.input,
                  outputs=model.output,
                  name='outputs_of_' + model.name)
    # Allow single tensors/layers instead of lists; create default input layers if none are given
    if not isinstance(start_tensors, list):
        start_tensors = [start_tensors]
    if input_layers is None:
        input_layers = [
            Input(shape=K.int_shape(x)[1:], dtype=K.dtype(x))
            for x in start_tensors
        ]
    elif not isinstance(input_layers, list):
        input_layers = [input_layers]
    # Core: Overwrite the model's input
    model.inputs = start_tensors
    model._input_layers = [x._keras_history[0] for x in input_layers]
    # If only a single input was given, return a single tensor rather than a list
    if len(input_layers) == 1:
        input_layers = input_layers[0]
    # Organize layers, referenced from Model's run_internal_graph function
    layers, tensor_map = [], set()
    for x in model.inputs:
        tensor_map.add(str(id(x)))
    depth_keys = list(model._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    for depth in depth_keys:
        nodes = model._nodes_by_depth[depth]
        for node in nodes:
            n = 0
            for x in node.input_tensors:
                if str(id(x)) in tensor_map:
                    n += 1
            if n == len(node.input_tensors):
                if node.outbound_layer not in layers:
                    layers.append(node.outbound_layer)
                for x in node.output_tensors:
                    tensor_map.add(str(id(x)))
    model._layers = layers # Keep only the used layers
    # Calculate outputs
    outputs = model(input_layers)
    return input_layers, outputs

Usage:

from keras.applications.resnet50 import ResNet50
model = ResNet50(weights='imagenet')

x, y = get_outputs_of(
    model,
    model.get_layer('add_15').output
)

model2 = Model(x, y)

The code is a bit long, but the logic is actually simple. The truly core code consists of just three lines:

model.inputs = start_tensors
model._input_layers = [x._keras_history[0] for x in input_layers]
outputs = model(input_layers)

By overriding model.inputs and model._input_layers, we trick the model into rebuilding itself from an intermediate layer. The rest is mostly adaptation work; in particular, model._layers = layers keeps only the layers that are actually used from the split point onward. This keeps the model's parameter count accurate: without it, the new model would still report the parameter count of the entire original model.
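
With get_outputs_of in hand, the layer insertion and replacement mentioned at the start of this subsection reduce to combining an ordinary front sub-model with a split-off back half. The following is only a sketch: the 1x1 convolution is a stand-in for whatever you actually want to insert, and 'add_15' is the same split point as in the usage example above.

from keras import backend as K
from keras.applications.resnet50 import ResNet50
from keras.layers import Conv2D, Input
from keras.models import Model

model = ResNet50(weights='imagenet')
split_tensor = model.get_layer('add_15').output

# Front half: inputs -> add_15 (an ordinary sub-model)
front = Model(model.input, split_tensor)
# Back half: add_15 -> outputs (obtained by splitting)
x_back, y_back = get_outputs_of(model, split_tensor)
back = Model(x_back, y_back)

# Insert something in between, e.g. a channel-preserving 1x1 convolution
x_in = Input(shape=(224, 224, 3))
h = front(x_in)
h = Conv2D(K.int_shape(h)[-1], 1, padding='same')(h) # the inserted layer
out = back(h)
new_model = Model(x_in, out)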

Summary

Keras is, to my mind, the most aesthetically pleasing deep learning framework; at least in terms of code readability, it stands alone. Some readers will bring up PyTorch, and PyTorch certainly has its strengths, but in terms of readability I do not think it matches Keras.

Through deep investigation of Keras, I have not only marveled at the profound and elegant programming skills of its authors but also felt my own programming skills improve. Indeed, many of my Python programming techniques were learned from studying the Keras source code.