6.3  Parameter Initialization

Now that we know how to access the parameters, let’s look at how to initialize them properly. We discussed the need for proper initialization in Section 5.4. The deep learning framework provides default random initializations to its layers. However, we often want to initialize our weights according to various other protocols. The framework provides the most commonly used protocols, and also allows us to create custom initializers.

import torch
from torch import nn
import tensorflow as tf
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()

By default, PyTorch initializes weight matrices and bias vectors uniformly by drawing from a range that is computed from the input dimension of the layer (for nn.Linear, the range is \(U(-1/\sqrt{\textrm{fan}_{\textrm{in}}}, 1/\sqrt{\textrm{fan}_{\textrm{in}}})\)). PyTorch’s nn.init module provides a variety of preset initialization methods.

By default, Keras initializes weight matrices uniformly by drawing from a range that is computed according to the input and output dimension, and the bias parameters are all set to zero. TensorFlow provides a variety of initialization methods both in the root module and the keras.initializers module.

By default, Flax initializes weights using jax.nn.initializers.lecun_normal, i.e., by drawing samples from a truncated normal distribution centered on 0 with the standard deviation set as the square root of \(1 / \textrm{fan}_{\textrm{in}}\), where fan_in is the number of input units in the weight tensor. The bias parameters are all set to zero. JAX’s nn.initializers module provides a variety of preset initialization methods.

By default, MXNet initializes weight parameters by randomly drawing from a uniform distribution \(U(-0.07, 0.07)\) and sets bias parameters to zero. MXNet’s init module provides a variety of preset initialization methods.

net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)

6.3.1 Built-in Initialization

Let’s begin by calling on built-in initializers. The code below initializes all weight parameters as Gaussian random variables with standard deviation 0.01, while bias parameters are cleared to zero.

def init_normal(module):
    """Initialize an ``nn.Linear`` layer: weights ~ N(0, 0.01^2), biases zero.

    Meant to be passed to ``net.apply``; modules that are not ``nn.Linear``
    (or a subclass of it) are left untouched. Lazy layers are assumed to have
    been materialized (e.g. by a forward pass) before this is applied.
    """
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0, std=0.01)
        # Layers built with ``bias=False`` have ``module.bias is None``.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([-0.0020,  0.0051, -0.0077, -0.0037]), tensor(0.))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
(<Variable path=sequential_1/dense_2/kernel, shape=(4, 4), dtype=float32, value=[[ 0.00841248  0.00462559  0.00900105 -0.00837825]
  [-0.01620247  0.00945647 -0.00324482  0.00845975]
  [-0.01536592 -0.00469847  0.01590085  0.00454095]
  [-0.00404854  0.00096786 -0.00723826  0.00587411]]>,
 <Variable path=sequential_1/dense_2/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([-0.01066845, -0.0009537 ,  0.00432578, -0.016987  ], dtype=float32),
 Array(0., dtype=float32))
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([ 0.00354961, -0.00614133,  0.0107317 ,  0.01830765])

We can also initialize all the weight parameters to a given constant value (say, 1), while the bias parameters are again cleared to zero.

def init_constant(module):
    """Set every weight of an ``nn.Linear`` layer to 1 and its bias to zero.

    Meant to be passed to ``net.apply``; modules that are not ``nn.Linear``
    (or a subclass of it) are left untouched.
    """
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, 1)
        # ``bias=False`` layers carry no bias tensor to clear.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1),
])

net(X)
net.weights[0], net.weights[1]
(<Variable path=sequential_2/dense_4/kernel, shape=(4, 4), dtype=float32, value=[[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]>,
 <Variable path=sequential_2/dense_4/bias, shape=(4,), dtype=float32, value=[0. 0. 0. 0.]>)
weight_init = nn.initializers.constant(1)

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1.])

We can also apply different initializers for certain blocks. For example, below we initialize the first layer with the Xavier initializer and initialize the second layer to a constant value of 42.

def init_xavier(module):
    """Apply Xavier (Glorot) uniform initialization to an ``nn.Linear`` weight.

    The bias is left as is; non-``nn.Linear`` modules are ignored.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    """Set every weight of an ``nn.Linear`` layer to the constant 42.

    The bias is left as is; non-``nn.Linear`` modules are ignored.
    """
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([ 0.3135,  0.0255, -0.2494, -0.6272])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42)),
])

net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<Variable path=sequential_3/dense_6/kernel, shape=(4, 4), dtype=float32, value=[[-0.11885554  0.47910386 -0.35207492 -0.82871675]
 [-0.03338087 -0.8190652  -0.12931484  0.6246043 ]
 [ 0.3715145   0.16259259 -0.44517642  0.36619192]
 [-0.52023375 -0.02970785  0.29457992  0.05001092]]>
<Variable path=sequential_3/dense_7/kernel, shape=(4, 1), dtype=float32, value=[[42.]
 [42.]
 [42.]
 [42.]]>
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([-0.1458586 ,  0.09845318, -0.01580709,  0.68841463], dtype=float32),
 Array([[42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.]], dtype=float32))
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
[-0.26102373  0.15249556 -0.19274211 -0.24742058]
[[42. 42. 42. 42. 42. 42. 42. 42.]]

6.3.1.1 Custom Initialization

Sometimes, the initialization methods we need are not provided by the deep learning framework. In the example below, we define an initializer for any weight parameter \(w\) using the following strange distribution:

\[ \begin{aligned} w \sim \begin{cases} U(5, 10) & \textrm{ with probability } \frac{1}{4} \\ 0 & \textrm{ with probability } \frac{1}{2} \\ U(-10, -5) & \textrm{ with probability } \frac{1}{4} \end{cases} \end{aligned} \]

Again, we implement a my_init function to apply to net.

Here we define a subclass of Initializer and implement the __call__ function that returns the desired tensor given the shape and data type.

JAX initialization functions take as arguments a PRNG key, the shape, and the dtype. Here we implement the function my_init that returns the desired tensor given the shape and data type.

Here we define a subclass of the Initializer class. Usually, we only need to implement the _init_weight function which takes a tensor argument (data) and assigns to it the desired initialized values.

def my_init(module):
    """Initialize ``nn.Linear`` weights from the custom mixture distribution.

    Each weight is drawn from U(-10, 10) and then zeroed when its magnitude is
    below 5. The result is U(5, 10) with probability 1/4, 0 with probability
    1/2, and U(-10, -5) with probability 1/4. The name and shape of the
    layer's first parameter are printed. Biases are left untouched, and
    non-``nn.Linear`` modules are ignored.
    """
    if isinstance(module, nn.Linear):
        # Only the first parameter is reported, so don't build the full list.
        name, param = next(module.named_parameters())
        print("Init", name, param.shape)
        nn.init.uniform_(module.weight, -10, 10)
        # Mask in place without recording the op in autograd, instead of
        # mutating through the discouraged ``.data`` attribute.
        with torch.no_grad():
            module.weight.mul_(module.weight.abs() >= 5)

net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[ 0.0000, -7.9698, -0.0000,  0.0000],
        [-0.0000, -9.9002,  0.0000, -0.0000]], grad_fn=<SliceBackward0>)
class MyInit(tf.keras.initializers.Initializer):
    """Keras initializer for the custom mixture distribution.

    Samples U(-10, 10) and zeroes every entry whose magnitude is below 5,
    giving U(5, 10) w.p. 1/4, 0 w.p. 1/2 and U(-10, -5) w.p. 1/4.
    """

    def __call__(self, shape, dtype=None):
        # ``dtype`` may be None when the initializer is called directly;
        # fall back to float32, tf.random.uniform's own default.
        if dtype is None:
            dtype = tf.float32
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        # Cast the mask to the sample's dtype (not a hard-coded float32) so
        # float64/float16 variables do not hit a dtype mismatch.
        factor = tf.cast(tf.abs(data) >= 5, data.dtype)
        return data * factor

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1),
])

net(X)
print(net.layers[1].weights[0])
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[ 8.718412  -0.        -6.8890595 -8.556049 ]
 [-6.912656   9.152588   6.130512   7.3861103]
 [ 0.        -8.536215   0.         9.882376 ]
 [ 7.6348896  0.        -0.         0.       ]]>
def my_init(key, shape, dtype=jnp.float_):
    """Flax-style initializer for the custom mixture distribution.

    Args:
        key: JAX PRNG key used for sampling.
        shape: shape of the parameter array to create.
        dtype: floating dtype of the returned array (now honored; it used
            to be ignored).

    Returns:
        An array of the given ``shape`` and ``dtype`` whose entries are 0
        with probability 1/2, and U(5, 10) or U(-10, -5) with probability
        1/4 each.
    """
    data = jax.random.uniform(key, shape, dtype=dtype, minval=-10, maxval=10)
    # Multiplying by the boolean mask zeroes every entry with |w| < 5.
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[-7.4473286  5.2821994]
 [-8.338981   0.       ]
 [ 0.         0.       ]
 [ 9.397041   0.       ]]
class MyInit(init.Initializer):
    """MXNet initializer drawing weights from the custom mixture distribution.

    Only ``_init_weight`` is overridden; it fills ``data`` in place.
    """

    def _init_weight(self, name, data):
        # Report which parameter is being initialized and its shape.
        print('Init', name, data.shape)
        # Step 1: overwrite the whole array with U(-10, 10) samples.
        data[:] = np.random.uniform(-10, 10, data.shape)
        # Step 2: zero every entry with |w| < 5 by multiplying in place
        # with the boolean mask.
        keep = np.abs(data) >= 5
        data *= keep

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
array([[-6.0683527,  8.991421 , -0.       ,  0.       ],
       [ 6.4198647, -9.728567 , -8.057975 ,  0.       ]])

Note that we always have the option of setting parameters directly.

Note that we always have the option of setting parameters directly.

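When initializing parameters in JAX and Flax, older Flax releases return the dictionary of parameters as a flax.core.frozen_dict.FrozenDict, which cannot be modified; newer releases return a regular dict. Either way, JAX arrays themselves are immutable, so it is not advisable in the JAX ecosystem to alter parameter values in place. To make changes, one might use flax.core.unfreeze(params) (or params.unfreeze() on a FrozenDict) to obtain a mutable dictionary. Array entries can then be updated functionally, e.g., w.at[0, 0].set(42), which returns a new array rather than modifying w.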

Note that we always have the option of setting parameters directly.

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42.0000, -6.9698,  1.0000,  1.0000])
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
<Variable path=sequential_4/dense_8/kernel, shape=(4, 4), dtype=float32, value=[[42.         1.        -5.8890595 -7.5560493]
 [-5.912656  10.152588   7.130512   8.38611  ]
 [ 1.        -7.536215   1.        10.882376 ]
 [ 8.63489    1.         1.         1.       ]]>
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42.      ,  9.991421,  1.      ,  1.      ])

6.3.2 Summary

We can initialize parameters using built-in and custom initializers.

6.3.3 Exercises

Look up the online documentation for more built-in initializers.