Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
   :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_deep-learning-computation/model-construction.ipynb

.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_deep-learning-computation/model-construction.ipynb

.. _sec_model_construction:

Layers and Blocks
=================

When we first introduced neural networks, we focused on linear models with a single output. Here, the entire model consists of just a single neuron. Note that a single neuron (i) takes some set of inputs; (ii) generates a corresponding (*scalar*) output; and (iii) has a set of associated parameters that can be updated to optimize some objective function of interest. Then, once we started thinking about networks with multiple outputs, we leveraged vectorized arithmetic to characterize an entire *layer* of neurons. Just like individual neurons, layers (i) take a set of inputs, (ii) generate corresponding outputs, and (iii) are described by a set of tunable parameters. When we worked through softmax regression, a single *layer* was itself *the model*. However, even when we subsequently introduced multilayer perceptrons, we could still think of the model as retaining this same basic structure.

Interestingly, for multilayer perceptrons, both the *entire model* and its *constituent layers* share this structure. The (entire) model takes in raw inputs (the features), generates outputs (the predictions), and possesses parameters (the combined parameters from all constituent layers). Likewise, each individual layer ingests inputs (supplied by the previous layer), generates outputs (the inputs to the subsequent layer), and possesses a set of tunable parameters that are updated according to the signal that flows backwards from the subsequent layer.
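To make this shared structure concrete, the following minimal sketch uses plain Java rather than DJL. The names ``Unit``, ``Dense``, and ``Chain`` are hypothetical and chosen only for illustration, and activation functions and training are left out to keep it short. The point is only that a layer and a model built from layers can expose the same two things: a forward computation and a set of parameters.

```java
import java.util.Arrays;
import java.util.List;

// Illustrative only: these types are NOT part of DJL.
final class LayerSketch {

    /** Anything that maps inputs to outputs and owns tunable parameters. */
    interface Unit {
        double[] forward(double[] x);

        int parameterCount();
    }

    /** A layer of neurons computing y = W x + b. */
    static final class Dense implements Unit {
        private final double[][] weights; // shape: (outputs, inputs)
        private final double[] bias;      // shape: (outputs)

        Dense(double[][] weights, double[] bias) {
            this.weights = weights;
            this.bias = bias;
        }

        @Override
        public double[] forward(double[] x) {
            double[] y = new double[bias.length];
            for (int i = 0; i < bias.length; i++) {
                y[i] = bias[i];
                for (int j = 0; j < x.length; j++) {
                    y[i] += weights[i][j] * x[j];
                }
            }
            return y;
        }

        @Override
        public int parameterCount() {
            return weights.length * weights[0].length + bias.length;
        }
    }

    /** A model: feeds the output of each layer into the next one. */
    static final class Chain implements Unit {
        private final List<Unit> layers;

        Chain(List<Unit> layers) {
            this.layers = layers;
        }

        @Override
        public double[] forward(double[] x) {
            double[] current = x;
            for (Unit layer : layers) {
                current = layer.forward(current);
            }
            return current;
        }

        @Override
        public int parameterCount() {
            // The model's parameters are the combined parameters of its layers.
            int total = 0;
            for (Unit layer : layers) {
                total += layer.parameterCount();
            }
            return total;
        }
    }

    public static void main(String[] args) {
        Unit hidden = new Dense(new double[][] {{1, 0}, {0, 1}, {1, 1}}, new double[] {0, 0, 1});
        Unit output = new Dense(new double[][] {{1, 1, 1}}, new double[] {0.5});
        Unit model = new Chain(Arrays.asList(hidden, output));

        System.out.println(Arrays.toString(model.forward(new double[] {2, 3}))); // [11.5]
        System.out.println(model.parameterCount());                              // 13
    }
}
```

Because ``Chain`` implements the same interface as ``Dense``, a chain could itself be placed inside another chain, which is the idea that the rest of this section develops with DJL's own classes.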
While you might think that neurons, layers, and models give us enough abstractions to go about our business, it turns out that we often find it convenient to speak about components that are larger than an individual layer but smaller than the entire model. For example, the ResNet-152 architecture, which is wildly popular in computer vision, possesses hundreds of layers. These layers consist of repeating patterns of *groups of layers*. Implementing such a network one layer at a time can grow tedious. This concern is not just hypothetical: such design patterns are common in practice. The ResNet architecture mentioned above won the 2015 ImageNet and COCO computer vision competitions for both recognition and detection :cite:`He.Zhang.Ren.ea.2016` and remains a go-to architecture for many vision tasks. Similar architectures in which layers are arranged in various repeating patterns are now ubiquitous in other domains, including natural language processing and speech.

To implement these complex networks, we introduce the concept of a neural network *block*. A block could describe a single layer, a component consisting of multiple layers, or the entire model itself! From a software standpoint, a ``Block`` is a *class*. Any subclass of ``Block`` must define a ``forward`` method that transforms its input into output and must store any necessary parameters. Note that some Blocks do not require any parameters at all! Finally, a ``Block`` must possess a ``backward`` method, for purposes of calculating gradients. Fortunately, due to some behind-the-scenes magic supplied by the ``autograd`` package (introduced in :numref:`chap_preliminaries`), when defining our own ``Block`` we only need to worry about parameters and the ``forward`` function.

One benefit of working with the ``Block`` abstraction is that Blocks can be combined into larger artifacts, often recursively (see the illustration in :numref:`fig_blocks`).

|Multiple layers are combined into blocks|

.. 
_fig_blocks:

By defining code to generate Blocks of arbitrary complexity on demand, we can write surprisingly compact code and still implement complex neural networks.

To begin, we revisit the Blocks that we used to implement multilayer perceptrons (:numref:`sec_mlp_djl`). The following code generates a network with one fully-connected hidden layer with 256 units and ReLU activation, followed by a fully-connected *output layer* with 10 units (no activation function).

.. |Multiple layers are combined into blocks| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/2885330e548958282a8dec1dca724eb0e533cfa9/img/blocks.svg

.. code:: java

    %load ../utils/djl-imports

.. code:: java

    NDManager manager = NDManager.newBaseManager();

    int inputSize = 20;
    NDArray x = manager.randomUniform(0, 1, new Shape(2, inputSize)); // (2, 20) shape

    Model model = Model.newInstance("lin-reg");
    SequentialBlock net = new SequentialBlock();

    net.add(Linear.builder().setUnits(256).build());
    net.add(Activation.reluBlock());
    net.add(Linear.builder().setUnits(10).build());
    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    net.initialize(manager, DataType.FLOAT32, x.getShape());

    model.setBlock(net);

Here we use a simple Translator for processing input and output. A ``NoopTranslator``, or No Operation Translator, does no processing: it simply takes in an NDList and outputs an NDList. For more complicated models, we can define our own translator and do preprocessing and postprocessing on the data. Here we pass in ``null`` for the ``Batchifier`` because the classes we define below break the structure that the default ``Batchifier`` expects.

.. code:: java

    // Pass `null` as the Batchifier, as described above
    Translator translator = new NoopTranslator(null);

We can then pass that into a model predictor to allow inference.

.. code:: java

    NDList xList = new NDList(x);

    Predictor predictor = model.newPredictor(translator);

    ((NDList) predictor.predict(xList)).singletonOrThrow();

.. 
parsed-literal::
    :class: output

    ND: (2, 10) gpu(0) float32
    [[ 0.0042, -0.0039, -0.0049, -0.0034, -0.0049, -0.0034, -0.0012, -0.005 , -0.0002, -0.0035],
     [ 0.0042, -0.0028, -0.0028, -0.0021, -0.0031, -0.0013, -0.0006, -0.003 ,  0.0002, -0.0045],
    ]

Notice that because ``x`` is an NDArray, we have to wrap it in an NDList before passing it in. Since ``predict()`` returns an Object, we also have to cast the return value of ``predict()`` to an NDList (we know that it should be an NDList from ``NoopTranslator``) to call its ``singletonOrThrow()`` function.

In this example, we constructed our model by instantiating a ``SequentialBlock``, assigning the returned object to the ``net`` variable. Next, we repeatedly called its ``add()`` method, appending layers in the order that they should be executed. In short, ``SequentialBlock`` defines a special kind of ``AbstractBlock`` that maintains an ordered list of constituent ``AbstractBlocks``. The ``add()`` method simply facilitates the addition of each successive ``AbstractBlock`` to the list. Note that each fully-connected layer is an instance of the ``Linear`` class, which is itself a subclass of ``AbstractBlock``. The ``forward()`` function is also remarkably simple: it chains each Block in the list together, passing the output of each as the input to the next.

A Custom Block
--------------

Perhaps the easiest way to develop intuition about how ``Block`` works is to implement one ourselves. Before we implement our own custom ``Block``, we briefly summarize the basic functionality that each ``Block`` must provide:

1. Ingest input data as arguments to its ``forward()`` method.
2. Generate an output by having ``forward()`` return a value. Note that the output may have a different shape from the input. For example, the first ``Linear`` layer in our model above ingests an input of arbitrary dimension but returns an output of dimension 256.
3. Calculate the gradient of its output with respect to its input, which can be accessed via its ``backward()`` method.
   Typically this happens automatically.
4. Store and provide access to those parameters necessary to execute the ``forward()`` computation.
5. Initialize these parameters as needed.

In the following snippet, we code up a Block from scratch corresponding to a multilayer perceptron with one hidden layer of 256 hidden units and a 10-dimensional output layer. Note that the ``MLP`` class below inherits from the ``AbstractBlock`` class. We rely heavily on the parent class's methods and implement only the methods that it requires us to override.

.. code:: java

    class MLP extends AbstractBlock {
        private static final byte VERSION = 1;

        // Keep the input size so that `initializeChildBlocks()` can use it
        private final int inputSize;

        private Block flattenInput;
        private Block hidden256;
        private Block output10;

        // Declare a layer with model parameters. Here, we declare two fully
        // connected layers
        public MLP(int inputSize) {
            super(VERSION); // Don't need to worry about this
            this.inputSize = inputSize;

            flattenInput = addChildBlock("flattenInput", Blocks.batchFlattenBlock(inputSize));
            hidden256 = addChildBlock("hidden256", Linear.builder().setUnits(256).build()); // Hidden Layer
            output10 = addChildBlock("output10", Linear.builder().setUnits(10).build()); // Output Layer
        }

        @Override
        // Define the forward computation of the model, that is, how to return
        // the required model output based on the input x
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDList current = inputs;
            current = flattenInput.forward(parameterStore, current, training);
            current = hidden256.forward(parameterStore, current, training);
            // We use the Activation.relu() function here
            // Since it takes in an NDArray, we call `singletonOrThrow()`
            // on the NDList `current` to get the NDArray and then
            // wrap it in a new NDList to be passed
            // to the next `forward()` call
            current = new NDList(Activation.relu(current.singletonOrThrow()));
            current = output10.forward(parameterStore, current, training);
            return current;
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputs) {
            Shape[]
            current = inputs;
            for (Block block : children.values()) {
                current = block.getOutputShapes(current);
            }
            return current;
        }

        @Override
        public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
            hidden256.initialize(manager, dataType, new Shape(1, inputSize));
            output10.initialize(manager, dataType, new Shape(1, 256));
        }
    }

To begin, let us focus on the ``forward()`` method. Note that it takes a ``ParameterStore``, ``NDList`` input, ``boolean`` training, and ``PairList<>`` params as input, but for now you only need to care about the ``NDList`` input. It then passes the inputs through each layer in turn. In this MLP implementation, both layers are instance variables. To see why this is reasonable, imagine instantiating two MLPs, ``net1`` and ``net2``, and training them on different data. Naturally, we would expect them to represent two different learned models.

We instantiate the MLP's layers in the constructor, initialize them in the ``initializeChildBlocks()`` method, and subsequently invoke these layers on each call to the ``forward()`` method. Note a few key details. First, our customized ``initializeChildBlocks()`` method invokes each child block's ``initialize()`` method, sparing us the pain of restating boilerplate code applicable to most Blocks. Second, the constructor creates our two ``Linear`` layers, registers them with ``addChildBlock()``, and keeps references to them in ``hidden256`` and ``output10``. Note that unless we implement a new operator, we need not worry about backpropagation (the ``backward`` method) or parameter initialization (the ``initialize`` method). DJL takes care of both automatically. We also do not have to call ``initializeChildBlocks()`` ourselves. Instead, we simply call ``AbstractBlock``'s ``initialize()`` method, which calls ``initializeChildBlocks()`` for us along with a few other things we don't need to worry about for now. Let us try this out:

.. 
code:: java

    MLP net = new MLP(inputSize);
    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    net.initialize(manager, DataType.FLOAT32, x.getShape());

    model.setBlock(net);

    Predictor predictor = model.newPredictor(translator);

    ((NDList) predictor.predict(xList)).singletonOrThrow();

.. parsed-literal::
    :class: output

    ND: (2, 10) gpu(0) float32
    [[ 0.001 , -0.0002,  0.0031,  0.0045,  0.003 ,  0.0038,  0.0037,  0.0045,  0.0009,  0.002 ],
     [ 0.0018, -0.0006,  0.0028,  0.0001,  0.0041,  0.0021,  0.0013, -0.0008,  0.0011,  0.0026],
    ]

A key virtue of the ``Block`` abstraction is its versatility. We can subclass ``AbstractBlock`` to create layers (such as the ``Linear`` block provided by DJL), entire models (such as the ``MLP`` above), or various components of intermediate complexity. We exploit this versatility throughout the following chapters, especially when addressing convolutional neural networks.

The Sequential Block
--------------------

We can now take a closer look at how the ``SequentialBlock`` class works. Recall that ``SequentialBlock`` was designed to daisy-chain other Blocks together. To build our own simplified ``MySequential``, we just need to define two key methods:

1. An ``add()`` method for appending Blocks one by one to a list.
2. A ``forward()`` method to pass an input through the chain of Blocks (in the same order as they were appended).

Additional helper methods we need to define are:

1. An ``initializeChildBlocks()`` method for initializing the child blocks.
2. A ``getOutputShapes()`` method for returning the output shape.

The following ``MySequential`` class delivers the same functionality as DJL's default ``SequentialBlock`` class:

.. code:: java

    class MySequential extends AbstractBlock {
        private static final byte VERSION = 2;

        public MySequential() {
            super(VERSION);
        }

        public MySequential add(Block block) {
            // Here, block is an instance of a Block subclass, and we assume it has
            // a unique name.
            // We add the child block to the children BlockList
            // with `addChildBlock()` which is defined in AbstractBlock.
            if (block != null) {
                addChildBlock(block.getClass().getSimpleName(), block);
            }
            return this;
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDList current = inputs;
            for (Block block : children.values()) {
                // BlockList guarantees that members will be traversed in the order
                // they were added
                current = block.forward(parameterStore, current, training);
            }
            return current;
        }

        @Override
        // Initializes all child blocks
        public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
            Shape[] shapes = inputShapes;
            for (Block child : getChildren().values()) {
                child.initialize(manager, dataType, shapes);
                shapes = child.getOutputShapes(shapes);
            }
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputs) {
            // Propagate the shapes through every child, in order, so that the
            // result is the output shape of the last child block
            Shape[] current = inputs;
            for (Block block : children.values()) {
                current = block.getOutputShapes(current);
            }
            return current;
        }
    }

The ``add()`` method adds a single Block to the block list ``children`` by using the ``addChildBlock()`` method implemented in ``AbstractBlock``. You might wonder why every DJL ``Block`` possesses a ``children`` attribute and why we used it rather than just defining a Java list ourselves. In short, the chief advantage of ``children`` is that during our Block's parameter initialization, DJL knows to look in the ``children`` list to find sub-Blocks whose parameters also need to be initialized.

When our ``MySequential`` Block's ``forward()`` method is invoked, each added ``Block`` is executed in the order in which it was added. We can now reimplement an MLP using our ``MySequential`` class.

.. 
code:: java

    MySequential net = new MySequential();
    net.add(Linear.builder().setUnits(256).build());
    net.add(Activation.reluBlock());
    net.add(Linear.builder().setUnits(10).build());

    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    net.initialize(manager, DataType.FLOAT32, x.getShape());

    Model model = Model.newInstance("my-sequential");
    model.setBlock(net);

    Predictor predictor = model.newPredictor(translator);
    ((NDList) predictor.predict(xList)).singletonOrThrow();

.. parsed-literal::
    :class: output

    ND: (2, 10) gpu(0) float32
    [[ 0.0054, -0.0037,  0.0034,  0.0021, -0.0048,  0.0027, -0.0022,  0.0045,  0.0014,  0.0005],
     [ 0.0053, -0.0016,  0.004 , -0.0018, -0.0056,  0.0016, -0.003 ,  0.0026,  0.0012, -0.0009],
    ]

Note that this use of ``MySequential`` mirrors the code we previously wrote for the DJL ``SequentialBlock`` class (as described in :numref:`sec_mlp_djl`).

Executing Code in the ``forward`` Method
----------------------------------------

The ``SequentialBlock`` class makes model construction easy, allowing us to assemble new architectures without having to define our own class. However, not all architectures are simple daisy chains. When greater flexibility is required, we will want to define our own ``Block``\ s. For example, we might want to execute Java's control flow within the forward method. Moreover, we might want to perform arbitrary mathematical operations, not simply relying on predefined neural network layers.

You might have noticed that until now, all of the operations in our networks have acted upon our network's activations and its parameters. Sometimes, however, we might want to incorporate terms that are neither the result of previous layers nor updatable parameters. In DJL, we call these *constant* parameters.
Say, for example, that we want a layer that calculates the function :math:`f(\mathbf{x},\mathbf{w}) = c \cdot \mathbf{w}^\top \mathbf{x}`, where :math:`\mathbf{x}` is the input, :math:`\mathbf{w}` is our parameter, and :math:`c` is some specified constant that is not updated during optimization.

In the following code, we will implement a model that could not easily be assembled using only predefined layers and ``SequentialBlock``.

.. code:: java

    class FixedHiddenMLP extends AbstractBlock {
        private static final byte VERSION = 1;

        private Block hidden20;
        private NDArray constantParamWeight;
        private NDArray constantParamBias;

        public FixedHiddenMLP() {
            super(VERSION);
            hidden20 = addChildBlock("denseLayer", Linear.builder().setUnits(20).build());
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDList current = inputs;

            // Fully connected layer
            current = hidden20.forward(parameterStore, current, training);
            // Use the constant parameters NDArray
            // Call the static helper `Linear.linear()` to do the calculation
            current = Linear.linear(current.singletonOrThrow(), constantParamWeight, constantParamBias);
            // ReLU activation
            current = new NDList(Activation.relu(current.singletonOrThrow()));
            // Reuse the fully connected layer. This is equivalent to sharing
            // parameters with two fully connected layers
            current = hidden20.forward(parameterStore, current, training);

            // Control flow: `getFloat()` extracts the scalar value
            // so that it can be compared with 1
            while (current.head().abs().sum().getFloat() > 1) {
                current.head().divi(2);
            }
            return new NDList(current.head().abs().sum());
        }

        @Override
        public void initializeChildBlocks(NDManager manager, DataType dataType, Shape...
                inputShapes) {
            Shape[] shapes = inputShapes;
            for (Block child : getChildren().values()) {
                child.initialize(manager, dataType, shapes);
                shapes = child.getOutputShapes(shapes);
            }
            // Initialize constant parameter layer
            constantParamWeight = manager.randomUniform(-0.07f, 0.07f, new Shape(20, 20));
            constantParamBias = manager.zeros(new Shape(20));
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputs) {
            return new Shape[]{new Shape(1)}; // we return a scalar so the shape is 1
        }
    }

In this ``FixedHiddenMLP`` model, we implement a hidden layer whose weights are initialized randomly when the block is initialized and are thereafter constant. This weight is not a model parameter and thus it is never updated by backpropagation. The network then applies a ReLU activation to the output of this *fixed* layer and passes the result through the same ``Linear`` layer a second time, so the two calls share parameters.

Note that before returning output, our model did something unusual. We ran a ``while`` loop, testing on the condition ``current.head().abs().sum().getFloat() > 1``, and halving our output vector for as long as the condition held. Finally, we returned the sum of the absolute values of the entries. To our knowledge, no standard neural network performs this operation. Note that this particular operation may not be useful in any real world task. Our point is only to show you how to integrate arbitrary code into the flow of your neural network computations.

.. code:: java

    xList

.. parsed-literal::
    :class: output

    NDList size: 1
    0 : (2, 20) float32

.. code:: java

    FixedHiddenMLP net = new FixedHiddenMLP();

    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    net.initialize(manager, DataType.FLOAT32, x.getShape());

    Model model = Model.newInstance("fixed-mlp");
    model.setBlock(net);

    Predictor predictor = model.newPredictor(translator);
    ((NDList) predictor.predict(xList)).singletonOrThrow();

.. parsed-literal::
    :class: output

    ND: () gpu(0) float32
    0.006

With DJL, we can mix and match various ways of assembling ``Block``\ s together.
In the following example, we nest ``Block``\ s in some creative ways.

.. code:: java

    class NestMLP extends AbstractBlock {
        private static final byte VERSION = 1;

        private SequentialBlock net;
        private Block dense;

        public NestMLP() {
            super(VERSION);
            net = new SequentialBlock();
            net.add(Linear.builder().setUnits(64).build());
            net.add(Activation.reluBlock());
            net.add(Linear.builder().setUnits(32).build());
            net.add(Activation.reluBlock());
            addChildBlock("net", net);

            dense = addChildBlock("dense", Linear.builder().setUnits(16).build());
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDList current = inputs;

            // Nested sequential block, followed by a fully connected layer
            current = net.forward(parameterStore, current, training);
            current = dense.forward(parameterStore, current, training);
            current = new NDList(Activation.relu(current.singletonOrThrow()));
            return current;
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputs) {
            Shape[] current = inputs;
            for (Block block : children.values()) {
                current = block.getOutputShapes(current);
            }
            return current;
        }

        @Override
        public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
            Shape[] shapes = inputShapes;
            for (Block child : getChildren().values()) {
                child.initialize(manager, dataType, shapes);
                shapes = child.getOutputShapes(shapes);
            }
        }
    }

    SequentialBlock chimera = new SequentialBlock();

    chimera.add(new NestMLP());
    chimera.add(Linear.builder().setUnits(20).build());
    chimera.add(new FixedHiddenMLP());

    chimera.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
    chimera.initialize(manager, DataType.FLOAT32, x.getShape());

    Model model = Model.newInstance("chimera");
    model.setBlock(chimera);

    Predictor predictor = model.newPredictor(translator);
    ((NDList) predictor.predict(xList)).singletonOrThrow();

.. parsed-literal::
    :class: output

    ND: () gpu(0) float32
    1.28119018e-08

Summary
-------

- Layers are Blocks.
- A Block can contain many layers.
- A Block can contain many Blocks.
- A Block can contain code.
- Blocks take care of lots of housekeeping, including parameter initialization and backpropagation.
- Sequential concatenations of layers and blocks are handled by the ``SequentialBlock`` Block.

Exercises
---------

1. Implement a block that takes two blocks as arguments, say ``net1`` and ``net2``, and returns the concatenated output of both networks in the forward pass (this is also called a parallel block).
2. Assume that you want to concatenate multiple instances of the same network. Implement a factory function that generates multiple instances of the same block and build a larger network from it.