
5.2. Parameter Management

Once we have chosen an architecture and set our hyperparameters, we proceed to the training loop, where our goal is to find parameter values that minimize our objective function. After training, we will need these parameters in order to make future predictions. Additionally, we will sometimes wish to extract the parameters, either to reuse them in some other context, to save our model to disk so that it may be executed in other software, or to examine them in the hope of gaining scientific understanding.

Most of the time, we will be able to ignore the nitty-gritty details of how parameters are declared and manipulated, relying on DJL to do the heavy lifting. However, when we move away from stacked architectures with standard layers, we will sometimes need to get into the weeds of declaring and manipulating parameters. In this section, we cover the following:

  • Accessing parameters for debugging, diagnostics, and visualizations.

  • Parameter initialization.

  • Sharing parameters across different model components.

We start by focusing on an MLP with one hidden layer.

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven ai.djl:model-zoo:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven net.java.dev.jna:jna:5.3.0
%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-a
import ai.djl.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.ndarray.index.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.training.*;
import ai.djl.training.initializer.*;
import ai.djl.training.dataset.*;
import ai.djl.util.*;
import ai.djl.translate.*;
import ai.djl.inference.Predictor;
NDManager manager = NDManager.newBaseManager();

NDArray x = manager.randomUniform(0, 1, new Shape(2, 4));

Model model = Model.newInstance("lin-reg");

SequentialBlock net = new SequentialBlock();

net.add(Linear.builder().setUnits(8).build());
net.add(Activation.reluBlock());
net.add(Linear.builder().setUnits(1).build());
net.setInitializer(new NormalInitializer());
net.initialize(manager, DataType.FLOAT32, x.getShape());

model.setBlock(net);

Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());

predictor.predict(new NDList(x)).singletonOrThrow(); // forward computation
ND: (2, 1) gpu(0) float32
[[-2.03669551e-05],
 [-1.32092864e-05],
]

5.2.1. Parameter Access

Let us start with how to access parameters from the models that you already know. Each parameter of a layer is stored in a Pair<String, Parameter>, consisting of a unique String key (the parameter's name) and the Parameter itself. Calling getParameters() on a Block returns a ParameterList, an extension of PairList that holds these pairs, and we use it below to inspect the parameters of the net defined above. When a model is defined via the SequentialBlock class, we can retrieve any Pair<String, Parameter> by calling get() on the ParameterList with the index of the parameter we want; calling getKey() and getValue() on that pair returns the parameter's name and the Parameter, respectively. We can also get a Parameter directly from the ParameterList, either by calling get() with its unique key (the String portion of the Pair<String, Parameter>) or by calling valueAt() with its index.

ParameterList params = net.getParameters();
// Print out all the keys (unique!)
for (var pair : params) {
    System.out.println(pair.getKey());
}

// Use the unique key to access the Parameter
NDArray dense0Weight = params.get("01Linear_weight").getArray();
NDArray dense0Bias = params.get("01Linear_bias").getArray();

// Use indexing to access the Parameter
NDArray dense1Weight = params.valueAt(2).getArray();
NDArray dense1Bias = params.valueAt(3).getArray();

System.out.println(dense0Weight);
System.out.println(dense0Bias);

System.out.println(dense1Weight);
System.out.println(dense1Bias);
01Linear_weight
01Linear_bias
03Linear_weight
03Linear_bias
weight: (8, 4) gpu(0) float32 hasGradient
[[ 0.0014, -0.0122,  0.0031,  0.0111],
 [-0.0004, -0.0071, -0.0129, -0.0088],
 [-0.0006, -0.0082,  0.0143, -0.0013],
 [ 0.0028,  0.0083, -0.0075, -0.0138],
 [ 0.01  , -0.0114, -0.0035,  0.0054],
 [-0.015 , -0.0122,  0.0124, -0.0027],
 [-0.0147, -0.0099,  0.0028,  0.0095],
 [ 0.0079, -0.0132,  0.0047,  0.0124],
]

bias: (8) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0.]

weight: (1, 8) gpu(0) float32 hasGradient
[[ 0.0084,  0.0148,  0.0031,  0.004 , -0.0089,  0.0029, -0.0037, -0.0014],
]

bias: (1) gpu(0) float32 hasGradient
[0.]

The output tells us a few important things. First, each fully-connected layer has two parameters, e.g., dense0Weight and dense0Bias, corresponding to that layer's weights and biases, respectively. The params variable is a ParameterList, which contains key-value pairs made up of each parameter's name and the corresponding Parameter object. From a Parameter, we can obtain the underlying numerical values as an NDArray by calling getArray(). Both the weights and biases are stored as single-precision floats (FLOAT32).
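To make the difference between key-based and index-based lookup concrete, here is a minimal plain-Java sketch of an insertion-ordered list of (key, value) pairs. The class OrderedPairs and its methods are hypothetical stand-ins that only mimic the get(String), get(int), and valueAt(int) style of access described above; this is not DJL's PairList implementation, and the String values merely stand in for Parameter objects.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: an insertion-ordered list of (key, value) pairs that,
// like the ParameterList described above, can be queried by position or by key.
// It is NOT DJL's implementation; String values stand in for Parameter objects.
class OrderedPairs<V> {
    private final List<Map.Entry<String, V>> entries = new ArrayList<>();

    void add(String key, V value) {
        entries.add(new AbstractMap.SimpleEntry<>(key, value));
    }

    int size() {
        return entries.size();
    }

    // Positional access to the whole pair (key and value).
    Map.Entry<String, V> get(int index) {
        return entries.get(index);
    }

    // Positional access straight to the value, skipping the pair.
    V valueAt(int index) {
        return entries.get(index).getValue();
    }

    // Key-based access: linear scan, returns null if the key is absent.
    V get(String key) {
        for (Map.Entry<String, V> e : entries) {
            if (e.getKey().equals(key)) {
                return e.getValue();
            }
        }
        return null;
    }

    public static void main(String[] args) {
        OrderedPairs<String> params = new OrderedPairs<>();
        // Keys follow the naming pattern printed above; values are placeholders.
        params.add("01Linear_weight", "weight (8, 4)");
        params.add("01Linear_bias", "bias (8)");
        params.add("03Linear_weight", "weight (1, 8)");
        params.add("03Linear_bias", "bias (1)");

        System.out.println(params.get("01Linear_bias")); // bias (8)
        System.out.println(params.valueAt(2));          // weight (1, 8)
        System.out.println(params.get(3).getKey());     // 03Linear_bias
    }
}
```

Because the pairs are kept in insertion order, index-based access is only meaningful when you know the order in which layers (and their weight and bias) were added; key-based access does not depend on that order.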

5.2.1.1. Targeted Parameters

Parameters are complex objects, containing data, gradients, and additional information. That's why we need to request the data explicitly. Note that the bias vector consists of zeros because biases are initialized to 0 and we have not updated the network since it was initialized.

Unlike the biases, the weights are nonzero, because weights are initialized randomly. In addition to getArray(), each Parameter also provides a requireGradient() method, which returns whether gradients need to be computed for the parameter (gradients are enabled on an NDArray with attachGradient()). The gradient has the same shape as the weight. To access the gradient, we call getGradient() on the NDArray. Because we have not invoked backpropagation for this network yet, its values are all 0. We would invoke it by creating a GradientCollector instance and running our calculations inside it.

dense0Weight.getGradient();
ND: (8, 4) gpu(0) float32
[[0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
]
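The statement that a gradient has the same shape as its weight can be checked by hand. The following plain-Java sketch (no DJL; the class name is hypothetical, and the loss, the sum of all outputs, is an assumption chosen for simplicity) computes the gradients of a dense layer y = x W^T + b analytically. Each entry W[i][j] receives exactly one partial derivative, so the gradient array naturally has the weight's (out, in) shape. In DJL this bookkeeping happens automatically when the computation is recorded inside a GradientCollector.

```java
// Hypothetical worked example (plain Java, no DJL): for a dense layer
// y = x * W^T + b with the scalar loss L = sum of all entries of y,
// dL/dW[i][j] = sum over the batch of x[n][j] and dL/db[i] = batch size.
// The gradient arrays have the same shapes as W and b.
class DenseGradSketch {

    // Returns dL/dW with the same (out, in) shape as the weight.
    static double[][] weightGrad(double[][] x, int outUnits) {
        int in = x[0].length;
        double[][] grad = new double[outUnits][in];
        for (int i = 0; i < outUnits; i++) {
            for (double[] row : x) {
                for (int j = 0; j < in; j++) {
                    grad[i][j] += row[j];
                }
            }
        }
        return grad;
    }

    // Returns dL/db with the same (out) shape as the bias.
    static double[] biasGrad(double[][] x, int outUnits) {
        double[] grad = new double[outUnits];
        java.util.Arrays.fill(grad, x.length);
        return grad;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3, 4}, {5, 6, 7, 8}}; // batch of 2, 4 features
        double[][] gW = weightGrad(x, 8);             // shape (8, 4), like dense0Weight
        System.out.println(gW.length + " x " + gW[0].length); // 8 x 4
        System.out.println(java.util.Arrays.toString(gW[0])); // [6.0, 8.0, 10.0, 12.0]
    }
}
```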

5.2.1.2. Collecting Parameters from Nested Blocks

Let us see how the parameter naming conventions work if we nest multiple blocks inside each other. For that we first define a function that produces Blocks (a Block factory, so to speak) and then combine these inside yet larger Blocks.

public SequentialBlock block1() {
    SequentialBlock net = new SequentialBlock();
    net.add(Linear.builder().setUnits(32).build());
    net.add(Activation.reluBlock());
    net.add(Linear.builder().setUnits(16).build());
    net.add(Activation.reluBlock());
    return net;
}

public SequentialBlock block2() {
    SequentialBlock net = new SequentialBlock();
    for (int i = 0; i < 4; i++) {
        net.add(block1());
    }
    return net;
}

SequentialBlock rgnet = new SequentialBlock();
rgnet.add(block2());
rgnet.add(Linear.builder().setUnits(10).build());
rgnet.setInitializer(new NormalInitializer());
rgnet.initialize(manager, DataType.FLOAT32, x.getShape());

Model model = Model.newInstance("rgnet");
model.setBlock(rgnet);

Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());

predictor.predict(new NDList(x)).singletonOrThrow();
ND: (2, 10) gpu(0) float32
[[-9.05861164e-15, -1.80095078e-14, -2.33998527e-14, -1.86868902e-14,  7.10750259e-15,  5.75573922e-15,  9.72335378e-16,  1.06593548e-14,  9.80970201e-15, -8.17016641e-15],
 [-4.27109291e-15, -7.85593921e-15, -9.57490109e-15, -7.16382689e-15,  2.99069440e-15,  2.62443375e-15,  6.40666075e-16,  4.29879427e-15,  4.13538595e-15, -3.19015266e-15],
]

Now that we have designed the network, let us see how it is organized. We can get the list of named parameters by calling getParameters(). However, we not only want to see the parameters, but also how our network is structured. To see our network architecture, we can simply print out the block whose architecture we want to see.

/* Network Architecture for RgNet */
rgnet
Sequential(
    Sequential(
            Sequential(
                    Linear(2 -> (2, 32))
                    Lambda()
                    Linear(2 -> (2, 16))
                    Lambda()
            )
            Sequential(
                    Linear(2 -> (2, 32))
                    Lambda()
                    Linear(2 -> (2, 16))
                    Lambda()
            )
            Sequential(
                    Linear(2 -> (2, 32))
                    Lambda()
                    Linear(2 -> (2, 16))
                    Lambda()
            )
            Sequential(
                    Linear(2 -> (2, 32))
                    Lambda()
                    Linear(2 -> (2, 16))
                    Lambda()
            )
    )
    Linear(2 -> (2, 10))
)
/* Parameters for RgNet */
for (var param : rgnet.getParameters()) {
    System.out.println(param.getValue().getArray());
}
weight: (32, 4) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (32) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 12 more]

weight: (16, 32) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (16) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

weight: (32, 16) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (32) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 12 more]

weight: (16, 32) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (16) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

weight: (32, 16) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (32) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 12 more]

weight: (16, 32) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (16) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

weight: (32, 16) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (32) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 12 more]

weight: (16, 32) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (16) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

weight: (10, 16) gpu(0) float32 hasGradient
[ Exceed max print size ]
bias: (10) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

Since the layers are hierarchically nested, we can also access them by calling a block's getChildren() method, which returns a BlockList (also an extension of PairList) of its inner blocks. BlockList shares its methods with ParameterList, so we can access the blocks with the same familiar calls: get(i) returns the Pair<String, Block> at index i, and getValue() on that pair returns the actual block. As with parameters, valueAt(i) does both in one step. To reach a deeper block, we repeat this on that block's children, and so on.

Here, we access the first major block, within it the second sub-block, and within that the bias of the first layer, as follows:

Block majorBlock1 = rgnet.getChildren().get(0).getValue();
Block subBlock2 = majorBlock1.getChildren().valueAt(1);
Block linearLayer1 = subBlock2.getChildren().valueAt(0);
NDArray bias = linearLayer1.getParameters().valueAt(1).getArray();
bias
bias: (32) gpu(0) float32 hasGradient
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 12 more]
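Both ideas from this subsection, collecting the parameters of nested blocks and walking down the hierarchy one child index at a time, can be sketched with a small hypothetical tree in plain Java. Node, collect(), and at() are illustrative names, not DJL classes; ReLU blocks are omitted because they hold no parameters, which means the child indices of later layers differ from DJL's (where each ReLU occupies an index).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch (plain Java, not DJL): a tree of blocks in which
// containers hold children and leaves hold named parameters. Parameters of
// nested blocks are gathered depth-first, and a nested block is reached by
// repeatedly indexing into children.
class BlockTreeSketch {

    static final class Node {
        final String name;
        final List<Node> children = new ArrayList<>();
        final List<String> params = new ArrayList<>();

        Node(String name, String... params) {
            this.name = name;
            this.params.addAll(Arrays.asList(params));
        }

        Node add(Node child) {
            children.add(child);
            return this;
        }
    }

    // Depth-first, left-to-right collection of "block/param" names.
    static List<String> collect(Node node) {
        List<String> out = new ArrayList<>();
        for (String p : node.params) {
            out.add(node.name + "/" + p);
        }
        for (Node child : node.children) {
            out.addAll(collect(child));
        }
        return out;
    }

    // Follow a path of child indices, like chaining getChildren().valueAt(i).
    static Node at(Node root, int... path) {
        Node current = root;
        for (int index : path) {
            current = current.children.get(index);
        }
        return current;
    }

    // A container with a 32-unit and a 16-unit layer (ReLUs omitted).
    static Node block1(int id) {
        return new Node("block1_" + id)
                .add(new Node("linear32_" + id, "weight", "bias"))
                .add(new Node("linear16_" + id, "weight", "bias"));
    }

    public static void main(String[] args) {
        Node block2 = new Node("block2");
        for (int i = 0; i < 4; i++) {
            block2.add(block1(i));
        }
        Node rgnet = new Node("rgnet").add(block2).add(new Node("linear10", "weight", "bias"));

        System.out.println(collect(rgnet).size()); // 18 parameters
        // First major block -> second sub-block -> first layer.
        System.out.println(at(rgnet, 0, 1, 0).name); // linear32_1
    }
}
```

The 18 collected names correspond to the 18 weight and bias arrays printed for rgnet above, in the same depth-first order.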

5.2.2. Parameter Initialization

Now that we know how to access the parameters, let us look at how to initialize them properly. We discussed the need for initialization in Section 4.8. By default, DJL initializes weight matrices using the initializer you set, and the bias parameters are all set to \(0\). However, we will often want to initialize our weights according to various other protocols. DJL's ai.djl.training.initializer package provides a variety of preset initialization methods. If we want to create a custom initializer, we need to do some extra work.

5.2.2.1. Built-in Initialization

In DJL, calling setInitializer() on a block does not, by default, overwrite any previously set initializers. So if you set an initializer earlier, then decide to change it and call setInitializer() again, the second call will NOT overwrite the first one.

Additionally, when you call setInitializer() on a block, all internal blocks will also call setInitializer() with the same given initializer.

This means that we can call setInitializer() on the highest level of a block and know that all internal blocks that do not have an initializer already set will be set to that given initializer.

This setup has the advantage that we don’t have to worry about our setInitializer() overriding our previous initializers on internal blocks!

If you do want to override an existing initializer, however, you can set an initializer on a Parameter directly by calling its setInitializer() method and passing true for the overwrite argument. Simply loop over all the parameters returned from getParameters() and set their initializers directly!
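The rules above (first initializer wins unless overwrite is true, and a block-level call propagates to all nested blocks without overwriting) can be modelled in a few lines. This is a hypothetical plain-Java sketch with made-up Param and Block classes and String initializer names; it is not DJL's code. The bias's preset "zeros" initializer is modelled on the text's statement that biases default to 0.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch (plain Java, not DJL's classes): a parameter keeps its
// first initializer unless overwrite is true, and setting an initializer on a
// block forwards it, without overwriting, to every nested parameter.
class InitRulesSketch {

    static final class Param {
        final String name;
        String initializer; // null means "not set yet"

        Param(String name, String defaultInitializer) {
            this.name = name;
            this.initializer = defaultInitializer;
        }

        void setInitializer(String init, boolean overwrite) {
            if (initializer == null || overwrite) {
                initializer = init;
            }
        }
    }

    static final class Block {
        final List<Param> params = new ArrayList<>();
        final List<Block> children = new ArrayList<>();

        // Block-level setting never overwrites; it only fills in missing ones.
        void setInitializer(String init) {
            for (Param p : params) {
                p.setInitializer(init, false);
            }
            for (Block child : children) {
                child.setInitializer(init);
            }
        }
    }

    // A linear-like block: the weight starts unset, the bias defaults to zeros.
    static Block linear() {
        Block b = new Block();
        b.params.add(new Param("weight", null));
        b.params.add(new Param("bias", "zeros"));
        return b;
    }

    public static void main(String[] args) {
        Block net = new Block();
        net.children.add(linear());
        net.children.add(linear());

        net.setInitializer("normal");    // fills in the unset weights
        net.setInitializer("constant1"); // no effect: weights already set
        Param w = net.children.get(0).params.get(0);
        Param b = net.children.get(0).params.get(1);
        System.out.println(w.initializer + ", " + b.initializer); // normal, zeros

        b.setInitializer("constant2", true); // explicit overwrite on one parameter
        System.out.println(b.initializer);  // constant2
    }
}
```

This model also explains the later examples in this section: a block-level initializer never reaches the biases, and overwriting them requires a per-parameter call with overwrite set to true.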

Let us begin by calling on built-in initializers. The code below attempts to initialize all parameters to the constant value 1 using the ConstantInitializer.

Note that this will not change anything yet, since we already set an initializer on this network in a previous code block. We can verify this by checking the weight of a parameter.

net.setInitializer(new ConstantInitializer(1));
net.initialize(manager, DataType.FLOAT32, x.getShape());
Block linearLayer = net.getChildren().get(0).getValue();
NDArray weight = linearLayer.getParameters().get(0).getValue().getArray();
weight
weight: (8, 4) gpu(0) float32 hasGradient
[[ 0.0014, -0.0122,  0.0031,  0.0111],
 [-0.0004, -0.0071, -0.0129, -0.0088],
 [-0.0006, -0.0082,  0.0143, -0.0013],
 [ 0.0028,  0.0083, -0.0075, -0.0138],
 [ 0.01  , -0.0114, -0.0035,  0.0054],
 [-0.015 , -0.0122,  0.0124, -0.0027],
 [-0.0147, -0.0099,  0.0028,  0.0095],
 [ 0.0079, -0.0132,  0.0047,  0.0124],
]

We can, however, see these initializations take effect if we create a new network. Let us write a function that conveniently creates this network architecture for us.

public SequentialBlock getNet() {
    SequentialBlock net = new SequentialBlock();
    net.add(Linear.builder().setUnits(8).build());
    net.add(Activation.reluBlock());
    net.add(Linear.builder().setUnits(1).build());
    return net;
}

If we set a constant initializer on this new net (this time with the value 7777) and check a parameter, we see that everything is initialized properly!

SequentialBlock net = getNet();
net.setInitializer(new ConstantInitializer(7777));
net.initialize(manager, DataType.FLOAT32, x.getShape());
Block linearLayer = net.getChildren().valueAt(0);
NDArray weight = linearLayer.getParameters().valueAt(0).getArray();
weight
weight: (8, 4) gpu(0) float32 hasGradient
[[7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
 [7777., 7777., 7777., 7777.],
]

We can also initialize all weight parameters as Gaussian random variables with standard deviation \(0.01\); the biases keep their default initialization to 0.

SequentialBlock net = getNet();
net.setInitializer(new NormalInitializer());
net.initialize(manager, DataType.FLOAT32, x.getShape());
Block linearLayer = net.getChildren().valueAt(0);
NDArray weight = linearLayer.getParameters().valueAt(0).getArray();
weight
weight: (8, 4) gpu(0) float32 hasGradient
[[-0.0177,  0.0105,  0.0094,  0.0044],
 [-0.0022, -0.0001,  0.0036, -0.004 ],
 [-0.0125, -0.0027,  0.0097,  0.0101],
 [ 0.0065, -0.002 ,  0.0073, -0.0172],
 [ 0.0097,  0.0089, -0.0052, -0.0107],
 [-0.0029,  0.0028, -0.0105, -0.0018],
 [ 0.0054,  0.003 ,  0.002 ,  0.0024],
 [ 0.015 ,  0.0065,  0.0025,  0.0031],
]
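What "Gaussian with standard deviation 0.01" means for a weight matrix can be illustrated without DJL: each entry is drawn independently as sigma times a standard normal sample. The following hypothetical plain-Java sketch (class and method names are illustrative, and it makes no claim about how NormalInitializer is implemented internally) draws such a matrix and checks the empirical spread on a large sample. The printed DJL weights above are all of order 0.01, as this scale suggests.

```java
import java.util.Random;

// Hypothetical sketch (plain Java): each weight entry is drawn independently
// as sigma * N(0, 1). This illustrates the distribution only; it is not how
// DJL's NormalInitializer is implemented internally.
class NormalInitSketch {

    static double[][] normal(int rows, int cols, double sigma, long seed) {
        Random rng = new Random(seed);
        double[][] w = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                w[i][j] = sigma * rng.nextGaussian();
            }
        }
        return w;
    }

    // Standard deviation over all entries (mean estimated from the data).
    static double std(double[][] w) {
        double sum = 0, sumSq = 0;
        int n = 0;
        for (double[] row : w) {
            for (double v : row) {
                sum += v;
                sumSq += v * v;
                n++;
            }
        }
        double mean = sum / n;
        return Math.sqrt(sumSq / n - mean * mean);
    }

    public static void main(String[] args) {
        // An (8, 4) matrix, like the first layer's weight above.
        double[][] w = normal(8, 4, 0.01, 42);
        System.out.println(w[0][0]);
        // With many more samples, the empirical spread approaches sigma.
        System.out.println(std(normal(1000, 1000, 0.01, 7))); // close to 0.01
    }
}
```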

We can also apply different initializers for certain Blocks. For example, below we initialize the first layer with the Xavier initializer and initialize the second layer to a constant value of 0.

We will do this without the getNet() function as it will be easier to have the reference to each block we want to set.

SequentialBlock net = new SequentialBlock();
Linear linear1 = Linear.builder().setUnits(8).build();
net.add(linear1);
net.add(Activation.reluBlock());
Linear linear2 = Linear.builder().setUnits(1).build();
net.add(linear2);

linear1.setInitializer(new XavierInitializer());
linear1.initialize(manager, DataType.FLOAT32, x.getShape());

linear2.setInitializer(Initializer.ZEROS);
// linear2 receives the 8 outputs of linear1 (batch size 2), not x itself
linear2.initialize(manager, DataType.FLOAT32, new Shape(2, 8));

System.out.println(linear1.getParameters().valueAt(0).getArray());
System.out.println(linear2.getParameters().valueAt(0).getArray());
weight: (8, 4) gpu(0) float32 hasGradient
[[ 0.0197, -0.4272,  0.2954,  0.2496],
 [-0.2387, -0.4842,  0.6798,  0.3475],
 [ 0.0094,  0.5641, -0.5202,  0.2189],
 [-0.1938,  0.6563, -0.5584,  0.4464],
 [ 0.4685, -0.6046,  0.5889, -0.5836],
 [-0.0543, -0.2023, -0.3847, -0.2716],
 [ 0.5264, -0.1341,  0.6531, -0.683 ],
 [-0.6154, -0.5642, -0.4575,  0.3256],
]

weight: (1, 8) gpu(0) float32 hasGradient
[[0., 0., 0., 0., 0., 0., 0., 0.],
]
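For intuition about the scale of the Xavier-initialized weights, here is the textbook Glorot/Xavier uniform rule, w ~ U[-a, a] with a = sqrt(6 / (fan_in + fan_out)), as a hypothetical plain-Java sketch. DJL's XavierInitializer is configurable (distribution type, how fan-in and fan-out are combined, magnitude), so this sketch does not claim to reproduce its defaults. For the first layer (4 inputs, 8 units), a = sqrt(6/12) ≈ 0.7071, and every entry printed above lies within ±0.7071, which is consistent with, but does not prove, this scale.

```java
import java.util.Random;

// Hypothetical sketch of the classic Glorot/Xavier *uniform* rule:
// w ~ U[-a, a] with a = sqrt(6 / (fanIn + fanOut)). DJL's XavierInitializer
// is configurable, so its defaults may scale differently.
class XavierSketch {

    static double uniformBound(int fanIn, int fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    // Weight of shape (fanOut, fanIn), matching the (units, inputs) layout above.
    static double[][] xavierUniform(int fanIn, int fanOut, long seed) {
        Random rng = new Random(seed);
        double a = uniformBound(fanIn, fanOut);
        double[][] w = new double[fanOut][fanIn];
        for (int i = 0; i < fanOut; i++) {
            for (int j = 0; j < fanIn; j++) {
                w[i][j] = (2 * rng.nextDouble() - 1) * a; // uniform in [-a, a)
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // First layer above: 4 inputs, 8 units.
        System.out.println(uniformBound(4, 8)); // sqrt(0.5), about 0.7071
    }
}
```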

Finally, we can loop over the ParameterList and set each parameter's initializer individually. When setting an initializer directly on a Parameter, you must also pass an overwrite boolean declaring whether the new initializer should replace one that has already been set.

For the weights in this example, no initializer has been set yet, so there is nothing to overwrite; passing false has the same outcome as passing true.

The bias parameters, however, are automatically set to initialize to 0, so to apply our initializer to them we must pass true for overwrite.

SequentialBlock net = getNet();
ParameterList params = net.getParameters();
for (int i = 0; i < params.size(); i++) {
    // Here we interleave initializers.
    // We initialize parameters at even indexes to 0
    // and parameters at odd indexes to 2.
    Parameter param = params.valueAt(i);
    if (i % 2 == 0) {
        // All weight parameters happen to be at even indices.
        // We set them to initialize to 0.
        // There is no need to overwrite
        // since no initializer has been set for them previously.
        param.setInitializer(new ConstantInitializer(0), false);
    }
    else {
        // All bias parameters happen to be at odd indices.
        // We set them to initialize to 2.
        // To set the initializer here properly, we must pass in true
        // for overwrite
        // since bias parameters automatically have their
        // initializer set to 0.
        param.setInitializer(new ConstantInitializer(2), true);
    }
}
net.initialize(manager, DataType.FLOAT32, x.getShape());

for (var param : net.getParameters()) {
    System.out.println(param.getKey());
    System.out.println(param.getValue().getArray());
}
01Linear_weight
weight: (8, 4) gpu(0) float32 hasGradient
[[0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
]

01Linear_bias
bias: (8) gpu(0) float32 hasGradient
[2., 2., 2., 2., 2., 2., 2., 2.]

03Linear_weight
weight: (1, 8) gpu(0) float32 hasGradient
[[0., 0., 0., 0., 0., 0., 0., 0.],
]

03Linear_bias
bias: (1) gpu(0) float32 hasGradient
[2.]

5.2.2.2. Custom Initialization

Sometimes, the initialization methods we need are not standard in DJL. In these cases, we can define a class to implement the Initializer interface. We only have to implement the initialize() function, which takes an NDManager, a Shape, and the DataType. We then create the NDArray with the aforementioned Shape and DataType and initialize it to what we want! You can also design your initializer to take in some parameters. Simply declare them as fields in the class and pass them in as inputs to the constructor! In the example below, we define an initializer for the following strange distribution:

\[
w \sim \begin{cases}
U[5, 10] & \text{ with probability } \frac{1}{4} \\
0 & \text{ with probability } \frac{1}{2} \\
U[-10, -5] & \text{ with probability } \frac{1}{4}
\end{cases} \tag{5.2.1}
\]
class MyInit implements Initializer {

    public MyInit() {}

    @Override
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        System.out.printf("Init %s\n", shape.toString());
        // Here we generate data points
        // from a uniform distribution [-10, 10]
        NDArray data = manager.randomUniform(-10, 10, shape, dataType);
        // We keep the data points whose absolute value is >= 5
        // and set the others to 0.
        // This generates the distribution `w` shown above.
        NDArray absGte5 = data.abs().gte(5); // returns boolean NDArray where
                                             // true indicates abs >= 5 and
                                             // false otherwise
        return data.mul(absGte5); // keeps true indices and sets false indices to 0.
                                  // special operation when multiplying a numerical
                                  // NDArray with a boolean NDArray
    }

}
SequentialBlock net = getNet();
net.setInitializer(new MyInit());
net.initialize(manager, DataType.FLOAT32, x.getShape());
Block linearLayer = net.getChildren().valueAt(0);
NDArray weight = linearLayer.getParameters().valueAt(0).getArray();
weight
Init (8, 4)
Init (1, 8)
weight: (8, 4) gpu(0) float32 hasGradient
[[ 5.192 ,  6.1485, -0.    ,  0.    ],
 [-7.9928, -0.    , -0.    , -0.    ],
 [-0.    ,  8.7888,  0.    , -0.    ],
 [-0.    ,  7.804 ,  8.9475, -7.8331],
 [-9.1163, -6.9159,  0.    , -8.4723],
 [ 0.    ,  0.    , -0.    ,  0.    ],
 [-0.    , -8.671 , -0.    , -5.3997],
 [ 8.7472,  0.    , -8.2616,  9.9264],
]
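Why does masking a U[-10, 10] sample produce this distribution? A draw has |w| >= 5 with probability 10/20 = 1/2, and by symmetry that mass splits evenly between [5, 10] and [-10, -5], while the remaining half is set to 0. The hypothetical plain-Java sketch below applies the same masking trick as MyInit and also shows the design point from the text: the initializer's settings (range and threshold) are fields passed in through the constructor. The class name and seed parameter are illustrative assumptions, not part of DJL.

```java
import java.util.Random;

// Hypothetical plain-Java version of the MyInit masking trick: draw from
// U[-range, range] and zero out draws whose magnitude is below the threshold.
// With range 10 and threshold 5, P(|w| >= 5) = 1/2, split evenly between
// [5, 10] and [-10, -5]; the other half of the draws become 0.
class MaskedUniformSketch {
    private final double range;
    private final double threshold;
    private final Random rng;

    MaskedUniformSketch(double range, double threshold, long seed) {
        this.range = range;
        this.threshold = threshold;
        this.rng = new Random(seed);
    }

    double[] sample(int n) {
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            double v = (2 * rng.nextDouble() - 1) * range; // U[-range, range)
            out[i] = Math.abs(v) >= threshold ? v : 0.0;   // mask small values
        }
        return out;
    }

    public static void main(String[] args) {
        double[] w = new MaskedUniformSketch(10, 5, 0).sample(100_000);
        int zeros = 0;
        for (double v : w) {
            if (v == 0.0) zeros++;
        }
        System.out.println((double) zeros / w.length); // roughly 0.5
    }
}
```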

Note that we always have the option of setting parameters directly by calling getValue().getArray() to access the underlying NDArray. A note for advanced users: you cannot directly modify parameters within a GradientCollector scope. You must modify them outside the GradientCollector scope to avoid confusing the automatic differentiation mechanics.

// Methods ending in i (addi, divi, ...) are in-place: they modify the original NDArray
NDArray weightLayer = net.getChildren().valueAt(0)
    .getParameters().valueAt(0).getArray();
weightLayer.addi(7);
weightLayer.divi(9);
weightLayer.set(new NDIndex(0, 0), 2020); // set the (0, 0) index to 2020
weightLayer;
weight: (8, 4) gpu(0) float32 hasGradient
[[ 2.02000000e+03,  1.46094930e+00,  7.77777791e-01,  7.77777791e-01],
 [-1.10314421e-01,  7.77777791e-01,  7.77777791e-01,  7.77777791e-01],
 [ 7.77777791e-01,  1.75430942e+00,  7.77777791e-01,  7.77777791e-01],
 [ 7.77777791e-01,  1.64489305e+00,  1.77194095e+00, -9.25637856e-02],
 [-2.35141858e-01,  9.33975633e-03,  7.77777791e-01, -1.63584173e-01],
 [ 7.77777791e-01,  7.77777791e-01,  7.77777791e-01,  7.77777791e-01],
 [ 7.77777791e-01, -1.85668409e-01,  7.77777791e-01,  1.77807391e-01],
 [ 1.74969053e+00,  7.77777791e-01, -1.40176028e-01,  1.88070786e+00],
]
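The many entries equal to 7.77777791e-01 come from positions where MyInit produced 0: (0 + 7) / 9 ≈ 0.7778. The hypothetical plain-Java helpers below (named after, but not equivalent to, the NDArray methods) show the in-place semantics on a float array: the method mutates the array it receives, so any other reference to the same storage sees the change.

```java
import java.util.Arrays;

// Hypothetical plain-Java analogue of in-place methods such as addi and divi:
// they mutate the array they are given instead of returning a new one.
class InPlaceSketch {

    static void addi(float[] a, float s) {
        for (int i = 0; i < a.length; i++) a[i] += s;
    }

    static void divi(float[] a, float s) {
        for (int i = 0; i < a.length; i++) a[i] /= s;
    }

    public static void main(String[] args) {
        // First values of the first column of the MyInit weight printed earlier.
        float[] weight = {5.192f, 0f, -7.9928f};
        float[] alias = weight;  // another reference to the same storage
        addi(weight, 7f);
        divi(weight, 9f);
        weight[0] = 2020f;       // analogous to setting element (0, 0) above
        System.out.println(Arrays.toString(alias)); // alias sees every change
    }
}
```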

5.2.3. Tied Parameters

Often, we want to share parameters across multiple layers. Later we will see that when learning word embeddings, it might be sensible to use the same parameters both for encoding and decoding words. We discussed one such case in Section 5.1. Let us see how to do this a bit more elegantly. In the following, we allocate a dense layer, wrap it together with a ReLU in a small SequentialBlock, and add that same block to the network twice, so that both occurrences use the same parameters.

SequentialBlock net = new SequentialBlock();

// Keep references to the shared layer and its wrapper so that the same block,
// and therefore the same parameters, can be added to the network twice
Block shared = Linear.builder().setUnits(8).build();
SequentialBlock sharedRelu = new SequentialBlock();
sharedRelu.add(shared);
sharedRelu.add(Activation.reluBlock());

net.add(Linear.builder().setUnits(8).build());
net.add(Activation.reluBlock());
net.add(sharedRelu);
net.add(sharedRelu);
net.add(Linear.builder().setUnits(10).build());

NDArray x = manager.randomUniform(-10f, 10f, new Shape(2, 20), DataType.FLOAT32);

net.setInitializer(new NormalInitializer());
net.initialize(manager, DataType.FLOAT32, x.getShape());

model.setBlock(net);

Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());
System.out.println(predictor.predict(new NDList(x)).singletonOrThrow());

// Check that the parameters are the same
NDArray shared1 = net.getChildren().valueAt(2)
    .getParameters().valueAt(0).getArray();
NDArray shared2 = net.getChildren().valueAt(3)
    .getParameters().valueAt(0).getArray();
shared1.eq(shared2);
ND: (2, 10) gpu(0) float32
[[ 1.11729014e-06, -4.48667834e-07,  2.15573596e-06, -5.15344709e-08,  8.61305182e-07, -1.66273469e-06, -1.01581463e-06, -7.41474537e-07,  3.13843088e-07, -1.05189793e-06],
 [ 6.84321890e-07,  2.93379884e-07,  7.24803613e-08, -7.04708100e-07,  1.18123033e-07, -6.60354033e-07, -1.39584381e-06,  2.69241358e-08,  4.60836247e-07, -1.21194239e-06],
]
ND: (8, 8) gpu(0) boolean
[[ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
 [ true,  true,  true,  true,  true,  true,  true,  true],
]

This example shows that the parameters of the second and third dense layers are tied. They are not just equal; they are represented by the same NDArray. Thus, if we change one of the parameters, the other one changes, too. You might wonder what happens to the gradients when parameters are tied. Since the model parameters contain gradients, the gradients of the second and third hidden layers are added together in the shared weight's gradient (shared.getParameters().valueAt(0).getArray().getGradient()) during backpropagation.
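Gradient accumulation for a tied parameter can be seen in a scalar example. In the hypothetical plain-Java sketch below (not DJL), one weight w is used by two "layers": h = w · x and then y = w · h, so y = w²x and dy/dw = 2wx. Backpropagation treats each use of w separately and adds both contributions into the single shared gradient, which recovers 2wx.

```java
// Hypothetical scalar illustration (plain Java, not DJL) of gradient
// accumulation for a tied parameter used twice: h = w * x, then y = w * h.
// Each use contributes w * x to dy/dw, and the contributions are summed.
class TiedGradSketch {

    static final class SharedWeight {
        double value;
        double grad; // accumulated gradient, starts at 0

        SharedWeight(double value) {
            this.value = value;
        }
    }

    // Forward and backward pass for y = w * (w * x); returns y.
    static double forwardBackward(SharedWeight w, double x) {
        double h = w.value * x;   // first use of w
        double y = w.value * h;   // second use of w
        // Contribution of the second use (h treated as constant): dy/dw = h
        w.grad += h;
        // Contribution of the first use via the chain rule: (dy/dh) * (dh/dw) = w * x
        w.grad += w.value * x;
        return y;
    }

    public static void main(String[] args) {
        SharedWeight w = new SharedWeight(3.0);
        double y = forwardBackward(w, 2.0);
        System.out.println(y);      // 18.0
        System.out.println(w.grad); // 12.0, equal to 2 * w * x
    }
}
```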

5.2.4. Summary

  • We have several ways to access, initialize, and tie model parameters.

  • We can use custom initialization.

  • DJL has a sophisticated mechanism for accessing parameters in a unique and hierarchical manner.

5.2.5. Exercises

  1. Use the FancyMLP defined in Section 5.1 and access the parameters of the various layers.

  2. Look at the DJL documentation and explore different initializers.

  3. Try accessing the model parameters after net.initialize() and before predictor.predict(new NDList(x)) to observe the shape of the model parameters. What changes? Why?

  4. Construct a multilayer perceptron containing a shared parameter layer and train it. During the training process, observe the model parameters and gradients of each layer.

  5. Why is sharing parameters a good idea?