
Converting your Keras model into a spiking neural network

Let’s set the scene: You know Keras, you’ve heard about spiking neural networks (SNNs), and you want to see what all the fuss is about. For some reason. Maybe it’s to take advantage of some cool neuromorphic edge AI hardware, maybe you’re into computational modeling of the brain, maybe you’re masochistic and like a challenge, or you just think SNNs are cool. I won’t question your motives as long as you don’t start making jokes about SkyNet.

Welcome! In this post I’m going to walk through using Nengo DL to convert models built using Keras into SNNs. Nengo is neural modeling and runtime software built and maintained by Applied Brain Research. We started it and have been using it in the Computational Neuroscience Research Group for a long time now. Nengo DL lets you build neural networks using the Nengo API, and then run them using TensorFlow. You can run any kind of network you want in Nengo (ANNs, RNNs, CNNs, SNNs, etc), but here I’ll be focusing on SNNs.

There are a lot of little things to watch out for in this process. In this post I’ll work through the steps to convert and debug a simple network that classifies MNIST digits. The goal is to show you how you can start converting your own networks, and some ways you can debug them if you encounter issues. Programming networks with temporal dynamics is a pretty unintuitive process and there’s lots of nuance to learn, but hopefully this will help you get started.

You can find an IPython notebook with all the code you need to run everything here up on my GitHub. The code I’ll be showing in this post is incomplete. I’m going to focus on the building, training, and conversion of the network and leave out parts like imports, loading in MNIST data, etc. To actually run this code you should get the IPython notebook, and make sure you have the latest Nengo, Nengo DL, and all the other dependencies installed.

Build your network in Keras and run it using NengoDL

The network is just going to be a convnet layer and then a fully connected layer. We build this in Keras all per uje … ushe … you-j … usual:

input = tf.keras.Input(shape=(28, 28, 1))
conv1 = tf.keras.layers.Conv2D(
    # the filter count and kernel size here are placeholders;
    # see the notebook for the exact values
    filters=32, kernel_size=3, activation=tf.nn.relu)(input)
flatten = tf.keras.layers.Flatten()(conv1)
dense1 = tf.keras.layers.Dense(units=10)(flatten)

model = tf.keras.Model(inputs=input, outputs=dense1)

Once the model is made, we can generate a Nengo network from it by calling the NengoDL Converter. We pass the converted network into the Simulator, compile with a standard classification loss function (sparse categorical cross-entropy), and start training.

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.RectifiedLinear()})
net = converter.net
nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]

# run training
with nengo_dl.Simulator(net, seed=0) as sim:
        # the optimizer choice here is a placeholder; use whatever
        # you would normally use in Keras
        optimizer=tf.optimizers.RMSprop(0.001),
        loss={nengo_output: tf.losses.SparseCategoricalCrossentropy(from_logits=True)},
    ){nengo_input: train_images}, {nengo_output: train_labels}, epochs=10)

    # save the parameters to file
    sim.save_params('./keras_to_snn_params')

In the Converter call you’ll see that there’s a swap_activations keyword. This is for us to map the TensorFlow activation functions to Nengo activation functions. In this case we’re just mapping ReLU to ReLU. After the training is done, we save the trained parameters to file.

Next, we can load our trained parameters, call the sim.predict function and plot the results:

n_test = 5
with nengo_dl.Simulator(net, seed=0) as sim:
    # load the parameters we saved after training
    sim.load_params('./keras_to_snn_params')
    data = sim.predict({nengo_input: test_images[:n_test]})

# plot the answer in a big red x
plt.plot(test_labels[:n_test].squeeze(), 'rx', mew=15)
# data[nengo_output].shape = (images, timesteps, n_outputs)
# plot predicted digit from network output on the last time step
plt.plot(np.argmax(data[nengo_output][:, -1], axis=1), 'o', mew=2)
Red x’s are the answers, blue o’s are the network predictions.

NengoDL inherently accounts for time in its simulations and so the data all needs to be formatted as (n_batches, n_timesteps, n_inputs). In this case everything is using standard rate mode neurons with no internal states that change over time, so simulating the network over time will generate the same output at every time step. When we move to spiking neurons, however, effects of temporal dynamics will be visible.
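
For example, here's roughly how the MNIST images get formatted for the NengoDL simulator (a sketch; the notebook linked above handles this for you, and the 30-time-step presentation length is just the value used in this example):

n_timesteps = 30
# flatten each 28x28 image and repeat it at every time step,
# giving data of shape (n_images, n_timesteps, 784)
test_images = test_images.reshape((test_images.shape[0], 1, -1))
test_images = np.tile(test_images, (1, n_timesteps, 1))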

Convert to spiking neurons

To convert our network to spiking neurons, all we have to do is change the neuron activation function that we map to in the converter when we’re generating the net:

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()})
net = converter.net
nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]

with nengo_dl.Simulator(net) as sim:
    # load the parameters trained in rate mode
    sim.load_params('./keras_to_snn_params')
    data = sim.predict({nengo_input: test_images[:n_test]})

So the process is that we create the model using Keras once. Then we can use a NengoDL Converter to create a Nengo network that can be simulated and trained. We save the parameters after training, and then we use another Converter to create another instance of the network that uses SpikingRectifiedLinear neurons as the activation function, loading in the trained parameters that we got from training with the standard RectifiedLinear rate mode activation function.

The reason that we don’t do the training with the SpikingRectifiedLinear activation function is because of its discontinuities, which cause errors when trying to calculate the derivative in backprop.

So! What kind of results do we get now that we’ve converted into spiking neurons?

Red x’s are the answers, blue o’s are the network predictions.

Not great! Why is this happening? Good question. It seems like what’s happening is the network is really convinced everything is 5. To investigate, we’re going to need to look at the neural activity.

Plotting the neural activity over time

To be able to see the activity of the neurons, we’re going to need 1) a reference to the ensemble of neurons we’re interested in monitoring, and 2) a Probe to track the activity of those neurons during simulation. This involves changing the network, so we’ll create another converted network and modify it before passing it into our Simulator.

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()})
net = converter.net

nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]
# get a reference for the neurons that we want to probe
nengo_conv1 = converter.layers[conv1]

# add a probe to the network to track the activity of those neurons!
with net:
    probe_conv1 = nengo.Probe(nengo_conv1, label='probe_conv1')

with nengo_dl.Simulator(net) as sim:
    sim.load_params('./keras_to_snn_params')
    data = sim.predict({nengo_input: test_images[:n_test]})

We can use Nengo’s handy rasterplot helper function to plot the activity of the first 3000 neurons:

from nengo.utils.matplotlib import rasterplot

n_timesteps = 30  # each image is presented for 30 time steps
n_neurons = 3000
# plot the neural activity from the first n_neurons on the
# first image in the batch, across all time steps
rasterplot(np.arange(n_timesteps), data[probe_conv1][0, :, :n_neurons])
Raster plot of the spiking neural activity

If you have a keen eye and are familiar with raster plots, you may notice that there are no spikes. Not a single one! So our network isn’t predicting a 5 for each input, it’s not predicting anything. We’re just getting 5 as output from a learned bias term. Bummer.

Let’s go back and see what kind of output our rate mode neurons are giving us, maybe that can help explain what’s going on. We can’t use a raster plot, because there are no spikes, but we can use a regular plot.

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.RectifiedLinear()})
net = converter.net

nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]
nengo_conv1 = converter.layers[conv1]

with net:
    probe_conv1 = nengo.Probe(nengo_conv1)

with nengo_dl.Simulator(net) as sim:
    sim.load_params('./keras_to_snn_params')
    data = sim.predict({nengo_input: test_images[:n_test]})

n_neurons = 5000
print('Max value: ', np.max(data[probe_conv1].flatten()))
# plot activity of first 5000 neurons, all inputs, all time steps
# we reshape the data so it's (n_batches * n_timesteps, n_neurons)
# for ease of plotting
plt.plot(data[probe_conv1][:, :, :n_neurons].reshape(-1, n_neurons))
Max value: 34.60704

Looking at the rate mode activity of the neurons, we see that the max firing rate is 34.6Hz. That’s about 1 spike every 30 milliseconds. Through an unlucky coincidence we’ve set each image to be presented to the network for 30ms. What could be happening is that neurons are building up to the point of spiking, but the input is switched before they actually spike. To test this, let’s increase the presentation time of each image to 100ms (100 time steps), and rerun our spiking simulator.
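
Assuming the images were tiled over time as sketched earlier, extending the presentation time is just a matter of re-tiling with more time steps:

n_timesteps = 100  # present each image for 100 time steps
test_images = test_images.reshape((test_images.shape[0], 1, -1))
test_images = np.tile(test_images, (1, n_timesteps, 1))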

We’ll also switch over to plotting spiking activity the same way as the rate neurons for consistency (and because it’s easier to see if neurons are spiking or not when there’s really sparse activity). One thing you may note below is that I’m plotting activity * dt, instead of just activity like in the rate neuron case. Whenever there’s a spike in Nengo, it’s recorded as 1/dt so that it integrates to 1. Multiplying the probe output by dt means that we see a 1 when there’s one spike per time step, a 2 when there’s two spikes per time step, etc. It just makes it a bit easier to read.

dt = 0.001  # the default Nengo simulation time step
print('Max value: ', np.max(data[probe_conv1].flatten() * dt))
plt.plot(data[probe_conv1][:, :, :n_neurons].reshape(-1, n_neurons) * dt)

Looking at this plot, we can see a few things. First, whenever an image is presented, there’s a startup period where no spikes occur. Ideally our network would give us some output without having to present the input for 50 time steps (the rate network gives us output after 1 time step!). We can address this by increasing the firing rates of the neurons in our network; we’ll come back to this.

Second, even now that we’re getting spikes, the predictions for each image are very poor. Why is that happening? Let’s look at the network output over time when we feed in the first test image:

Output from the spiking network over 100ms (being shown a single image)

From this plot it looks like there’s not a clear prediction so much as a bunch of noise. One factor that can contribute to this is that the way things are set up right now, when a neuron spikes that information is only processed by the receiving side for 1 time step. Anthropomorphizing our network, you can think of the output layer as receiving input along the lines of “nothing … nothing … nothing … IT’S A 5! … maybe a 2? … maybe a 3? … IT’S A 5! ..” etc.

It would be useful for us to do a bit of averaging over time. Enter: synapses!

Using synapses to smooth the output of a spiking network

Synapses can come in many forms. The default form in Nengo is a low-pass filter. What this does for us is let the post-synaptic neuron (i.e. the neuron that we’re sending information to) do a bit of integration of the information that is being sent to it. So in the above example the output layer would be receiving input like “nothing … nothing … nothing … looking like a 5 … IT’S A 5! … it’s a 5! … it’s probably a 5 but maybe also a 2 or 3 … IT’S A 5! … ” etc.
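
To get a feel for what the filter itself does, you can apply one directly to the spike train we recorded earlier; this is just a sketch using Nengo’s Lowpass synapse on the probe data from above:

# filter the spike train from the first image with a low-pass
# synapse to see the smoothing effect
synapse = nengo.Lowpass(tau=0.005)
filtered = synapse.filt(data[probe_conv1][0], dt=0.001)
plt.plot(filtered[:, :100])  # plot the first 100 neurons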

It will likely be more useful for understanding to see the actual network output over time with different low-pass filters applied than to read strained metaphors.

Applying a low-pass filter synapse to all of the connections in our network is easy enough; we just add another modification to the network before passing it into the Simulator:

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()})
net = converter.net

nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]
nengo_conv1 = converter.layers[conv1]

with net:
    probe_conv1 = nengo.Probe(nengo_conv1)

# set a low-pass filter value on all synapses in the network
for conn in net.all_connections:
    conn.synapse = 0.001

with nengo_dl.Simulator(net) as sim:
    sim.load_params('./keras_to_snn_params')
    data = sim.predict({nengo_input: test_images[:n_test]})

And here are the results we get for different low-pass filter time constants:

As the time constant on the low-pass filter increases, we can see the output of the network start smoothing out. It’s important to recognize, though, that a few things are going on when we filter values on all of the synapses of the network. The first is that we’re no longer sending sharp spikes between layers; we’re now passing along filtered spikes. The larger the time constant on the filter, the more spread out and smoother the signal will be.

As a result of this: if sending a spike from neuron A to neuron B used to cause neuron B to spike immediately when there was no synaptic filter, that may no longer be the case. It may now take several spikes from neuron A, all close together in time, to make neuron B spike.

Another thing we want to consider is that we’ve applied a synaptic filter at every layer, so the dynamics of the entire network have changed. Very often you’ll want to be more surgical with the synapses you create, leaving some connections with no synapse and some with large filters to get the best performance out of your network. Currently the way to do this is to print out net.all_connections, find the connections of interest, and then index in specific values. When we print out net.all_connections for this network, we get:

[<Connection at 0x7fd8ac65e4a8 from <Node "conv2d.0.bias"> to <Node (unlabeled) at 0x7fd8ac65ee48>>,
<Connection at 0x7fd969a06b70 from <Node (unlabeled) at 0x7fd8ac65ee48> to <Neurons of <Ensemble "conv2d.0">>>,
<Connection at 0x7fd969a06390 from <Node "input_1"> to <Neurons of <Ensemble "conv2d.0">>>,
<Connection at 0x7fd8afc41470 from <Node "dense.0.bias"> to <Node "dense.0">>,
<Connection at 0x7fd8afc41588 from <Neurons of <Ensemble "conv2d.0">> to <Node "dense.0">>]

The connections of interest for us are from input_1 to conv2d.0, and from conv2d.0 to dense.0. These are the connections the input signals for the network flow through; the rest of the connections just send in trained bias values to each layer. We can set the synapse value for these connections specifically with the following:

synapses = [None, None, 0.001, None, 0.001]
for conn, synapse in zip(net.all_connections, synapses):
    conn.synapse = synapse

In this case, just playing around with different values, I wasn’t able to find any synapse settings that did better than getting 4 of the 5 test images correct. But in general, being able to set specific synapse values in a spiking neural network is important, and you should be aware of how to do it to get the best performance out of your network.

So setting the synapses is able to improve the performance of the network. We’re still taking 100ms to generate output though, and only getting 4/5 of the test set correct. Let’s go back now to the first issue we identified and look at increasing the firing rates of the neurons.

Increasing the firing rates of neurons in the network

There are a few ways to go about this. The first is a somewhat cheeky method that works best for rectified linear (ReLU) neurons, and the second is a more general method that adjusts how training is performed.

Scaling ReLU firing rates

Because we’re using rectified linear neurons in this model, one trick that we can use to increase the firing rates without affecting the functionality of the network is to use a scaling term to multiply the input and divide the output of each neuron. This scaling works because the rectified linear neuron model is linear in its activation function.
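
If you want to convince yourself that this scaling really is invisible in rate mode, here’s a quick standalone check in NumPy:

def relu(x):
    return np.maximum(x, 0)

x = np.linspace(-1, 1, 11)
scale = 50
# scaling the input up and the output down by the same factor
# changes nothing, because ReLU is linear wherever it's nonzero
assert np.allclose(relu(x), relu(scale * x) / scale)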

The result of this scaling is more frequent spiking at a lower amplitude. We can implement this using the Nengo DL Converter with the scale_firing_rates keyword:

gain_scale = 20  # we'll try values of 5, 20, and 50
converter = nengo_dl.Converter(
    model,
    swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()},
    scale_firing_rates=gain_scale,
)

Let’s look at the network output and neural activity plots for gain_scale values of [5, 20, 50].

One thing that’s apparent is that as the firing rates go up, the performance of the network gets better. You may notice that for the first image (first 30 time steps) in the spiking activity plots there’s no response. Don’t read too much into that; I’m plotting a random subset of neurons and it just happens that none of them respond to the first image. If I were to plot the activity of all of the neurons we’d see spikes everywhere.

You may also notice that when gain_scale = 50 we’re even getting some neurons that are spiking 2 times per time step. That will happen when the input to the neuron causes the internal state to jump up to twice the threshold for spiking for that neuron. This is not unexpected behaviour.

Using this scale_firing_rates keyword in the Converter is one way to get the performance of our converted spiking networks to match the performance of rate neuron networks. However, it’s mainly a trick useful for ReLUs (and other linear activation functions). It would behoove us to figure out another method that works for nonlinear activation functions as well.

Adding a firing rate term to the cost function during training

Let’s go back to the network training step. Another way to bring the firing rates up is to add a term to the cost function that penalizes any firing rates outside of some desired range. There are a ton of ways to go about this with different kinds of cost functions. I’m just going to present one cost function term that works for this situation, and note that you can build this cost function a whole bunch of different ways. Here’s one:

def put_in_range(x, y, weight=100.0, min=200, max=300):
    index_greater = (y > max)  # find neurons firing faster
    index_lesser = (y < min)  # find neurons firing slower
    error = tf.reduce_sum(y[index_greater] - max) + tf.reduce_sum(min - y[index_lesser])
    return weight * error

The weight parameter lets us set the relative importance of the firing rates cost function term relative to the classification accuracy cost function term. To use this term we need to make a couple of adjustments to our code for training the network:

converter = nengo_dl.Converter(
    model, swap_activations={tf.nn.relu: nengo.RectifiedLinear()})
net = converter.net

nengo_input = converter.inputs[input]
nengo_output = converter.outputs[dense1]

nengo_conv1 = converter.layers[conv1]
with net:
    probe_conv1 = nengo.Probe(nengo_conv1, label='probe_conv1')

# run training
with nengo_dl.Simulator(net) as sim:
        optimizer=tf.optimizers.RMSprop(0.001),
        loss={
            nengo_output: tf.losses.SparseCategoricalCrossentropy(from_logits=True),
            probe_conv1: put_in_range,
        },
    )
        {nengo_input: train_images},
        {nengo_output: train_labels,
         probe_conv1: np.zeros(train_labels.shape)},
        epochs=10,
    )

    # save the parameters to file
    sim.save_params('./keras_to_snn_firingrates_params')

Mainly what’s been added in this code is our new loss function put_in_range, and in the sim.compile call we added probe_conv1: put_in_range to the loss dictionary. This tells Nengo DL to use the put_in_range cost function on the output from probe_conv1, which will be the firing rates of the convolutional layer of neurons in the network.

We also had to add probe_conv1: np.zeros(train_labels.shape) to the target dictionary in the sim.fit call. The array specified here is used as the x input to the put_in_range cost function, but since we defined put_in_range to be fully determined by y alone (which is the output from probe_conv1), it doesn’t matter what values we pass in there. So I pass in an array of zeros.

Now when we run the training and prediction in rate mode, the output we get looks like

And we can see that we’re still getting the same performance, but now the firing rates of the neurons are much higher. Let’s see what happens when we convert to spiking neurons now!

Hey, that’s looking much better! This of course is only looking at 5 test images and you’ll want to go through and calculate proper performance statistics using a full test set, but it’s a good start.


This post has looked at how to take a model that you built in Keras and convert it over to a spiking neural network using Nengo DL’s Converter function. This was a simple model, but hopefully it gets across that the conversion to spikes can be an iterative process, and you now have a better sense of some of the steps that you can take to investigate and debug spiking neural network behaviour! In general when tuning your network you’ll use a mix of the different methods we’ve gone through here, depending on the exact situation.

Again a reminder that all of the code for this can be found up on my GitHub.

Also! It’s very much worth checking out the Nengo DL documentation and the other examples they have there. There’s a great introduction for users coming from TensorFlow over to Nengo, and other examples showing how you can integrate non-spiking networks with spiking networks, as well as other ways to optimize your spiking neural networks. If you start playing around with Nengo and have more questions, please feel free to ask in the comments below, or even better go to the Nengo forums!


Natural policy gradient in TensorFlow

In working towards reproducing some results from deep learning control papers, one of the learning algorithms that came up was natural policy gradient. The basic idea of natural policy gradient is to use the curvature information of the policy’s distribution over actions in the weight update. There are good resources that go into the details of the natural policy gradient method (such as here and here), so I’m not going to go into details in this post. Rather, possibly just because I’m new to TensorFlow, I found implementing this method less than straightforward due to a number of unexpected gotchas. So I’m going to quickly review the algorithm and then mostly focus on the implementation issues that I ran into.

If you’re not familiar with policy gradient or various approaches to the cart-pole problem I also suggest that you check out this blog post from KV Frans, which provides the basis for the code I use here. All of the code that I use in this post is available up on my GitHub, and it’s worth noting this code is for TensorFlow 1.7.0.

Natural policy gradient review

Let’s informally define our notation

  • s_t and a_t are the system state and chosen action at time t
  • \theta are the network parameters that we’re learning
  • \pi_\theta is the control policy network parameterized by \theta
  • A^\pi(s_t, a_t, t) is the advantage function, which returns the estimated total reward for taking action a_t in state s_t relative to default behaviour (i.e. was this action better or worse than what we normally do)

The regular policy gradient is calculated as:

\textbf{g} = \frac{1}{T}\sum_{t=0}^T \nabla_\theta \; \textrm{log} \; \pi_\theta(a_t | s_t) \; \hat{A}^\pi(s_t, a_t, t),

where \nabla_\theta denotes calculating the partial derivative with respect to \theta, the hat above A^\pi denotes that we’re using an estimation of the advantage function, and \pi_\theta(a_t | s_t) returns a scalar value which is the probability of the chosen action a_t in state s_t.

For the natural policy gradient, we’re going to calculate the Fisher information matrix, F, which estimates the curvature of the control policy:

\hat{F}_{\theta_k} = \frac{1}{T} \sum_{t=0}^T \nabla_\theta \; \textrm{log} \; \pi_\theta(a_t | s_t) \; \nabla_\theta \; \textrm{log} \; \pi_\theta (a_t | s_t)^T.

The goal of including curvature information is to be able to move along the direction of steepest ascent, which we estimate with \hat{\textbf{F}}^{-1}_{\theta_k}\textbf{g}. Including this, the weight update becomes

\theta_{k+1} = \theta_k + \alpha \hat{\textbf{F}}^{-1}_{\theta_k} \textbf{g},

where \alpha is a learning rate parameter. In his paper Towards Generalization and Simplicity in Continuous Control, Aravind Rajeswaran notes that empirically it’s difficult to choose an appropriate \alpha value or set an appropriate schedule. To address this, he normalizes the update under the Fisher metric:

\theta_{k+1} = \theta_k + \sqrt{\frac{\delta}{\textbf{g}^T\hat{\textbf{F}}^{-1}_{\theta_k}\textbf{g}}} \; \hat{\textbf{F}}^{-1}_{\theta_k} \textbf{g}.

Using the above for the learning rate can be thought of as choosing a step size \delta that is normalized with respect to the change in the control policy. This is beneficial because it means that we don’t do any parameter updates that will drastically change the output of the policy network.
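
In NumPy terms, the update looks something like the following sketch, where the stand-in values for g, F, and theta would really come from the sampled trajectories and the policy network:

import numpy as np

n_params = 10
g = np.random.randn(n_params)  # stand-in policy gradient
F = np.eye(n_params)  # stand-in Fisher information matrix
theta = np.zeros(n_params)  # policy parameters
delta = 0.05  # normalized step size parameter

nat_grad = np.linalg.solve(F, g)  # F^-1 g
# normalize the step size under the Fisher metric
alpha = np.sqrt(delta / np.dot(g, nat_grad))
theta += alpha * nat_grad  # gradient ascent step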

Here’s what the pseudocode for the algorithm looks like

  • Initialize policy parameters \theta_0
  • for k = 1 to K do:
    • collect N trajectories by rolling out the stochastic policy \pi_{\theta_k}
    • compute \nabla_\theta \; \textrm{log} \; \pi_\theta(a_t | s_t) for each (s, a) pair along the trajectories sampled
    • compute advantages \hat{A}^\pi_k based on the sampled trajectories and the estimated value function V^\pi_{k-1}
    • compute the policy gradient g as specified above
    • compute the Fisher matrix and perform the weight update as specified above
    • update the value network to get V^\pi_k

OK great! On to implementation.

Implementation in TensorFlow 1.7.0

Calculating the gradient at every time step

In TensorFlow, the tf.gradients(ys, xs) function, which calculates the partial derivative of ys with respect to xs:

\nabla_{X} Y

will automatically sum over all elements in ys. There is no built-in function for getting the gradient of each individual entry in ys (i.e. the Jacobian), which is less than great. People have complained, but until some update happens we can address this with basic list comprehension, iterating through each entry in the history and calculating the gradient at that point in time. If you have a faster means that you’ve tested, please comment below!

The code for this looks like

        g_log_prob = tf.stack(
            [tf.gradients(action_log_prob_flat[i], my_variables)[0]
                for i in range(action_log_prob_flat.get_shape()[0])])

where the result is called g_log_prob because it’s the gradient of the log probability of the chosen action at each time step.

Inverse of a positive definite matrix with rounding errors

One of the problems I ran into when calculating the normalized learning rate \sqrt{\frac{\delta}{\textbf{g}^T\hat{\textbf{F}}^{-1}_{\theta_k}\textbf{g}}} was that the square root calculation was dying because negative values were being passed in. This was very confusing, because \hat{\textbf{F}} is a positive definite matrix, which means that \hat{\textbf{F}}^{-1} is also positive definite, which means that \textbf{x}^T \hat{\textbf{F}}^{-1} \textbf{x} \geq 0 for any \textbf{x}.

So it was clear that it was a calculation error somewhere, but I couldn’t find the source of the error for a while, until I started doing the inverse calculation myself using SVD to explicitly look at the eigenvalues and make sure they were all > 0. Turns out that there were some very, very small eigenvalues with a negative sign, like -1e-7 kind of range. My guess (and hope) is that these are just rounding errors, but they’re enough to mess up the calculations of the normalized learning rate. So, to handle this I explicitly set those values to zero, and that took care of that.

        # compute the pseudo-inverse of F using the SVD, explicitly
        # zeroing out any tiny (rounding error) singular values
        S, U, V = tf.svd(F)
        atol = tf.reduce_max(S) * 1e-6
        S_inv = tf.divide(1.0, S)
        S_inv = tf.where(S < atol, tf.zeros_like(S), S_inv)
        S_inv = tf.diag(S_inv)
        # F_inv = V * S_inv * U^T
        F_inv = tf.matmul(S_inv, tf.transpose(U))
        F_inv = tf.matmul(V, F_inv)

Manually adding an update to the learning parameters

Once the parameter update is explicitly calculated, we then need to apply it. To implement an explicit parameter update in TensorFlow we do

        # update trainable parameters
        my_variables[0] = tf.assign_add(my_variables[0], update)

So to update the parameters we then call, feed_dict={...}). It’s worth noting here too that any time you run the session fetching my_variables the parameter update will be calculated and applied, so you can’t run the session with my_variables just to fetch the current values.
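
Here’s a minimal standalone demonstration of that gotcha (TensorFlow 1.x):

import tensorflow as tf

w = tf.Variable([1.0, 2.0])
update = tf.constant([0.1, 0.1])
# fetching w_updated applies the update as a side effect
w_updated = tf.assign_add(w, update)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(w_updated))  # [1.1, 2.1]
    print(sess.run(w_updated))  # [1.2, 2.2] -- updated again!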

Results on the CartPole problem

Now, running the original policy gradient algorithm against the natural policy gradient algorithm (with everything else the same), we can see that using the Fisher information matrix in the update provides some strong benefits. The plot below shows the maximum reward received in a batch of 200 time steps, where the system receives a reward of 1 for every time step that the pole stays upright, and 200 is the maximum reward achievable.
To generate this plot I ran 10 sessions of 300 batches, where each batch runs as many episodes as it takes to get 200 time steps of data. The solid lines are the mean value at each epoch across all sessions, and the shaded areas are the 95% confidence intervals. So we can see that the natural policy gradient starts to hit a reward of 200 inside 100 batches, and that the mean stays higher than normal policy gradient even after 300 batches.

It’s also worth noting that for both the policy gradient and natural policy gradient the average time to run a batch and weight update was about the same on my machine, between 150-180 milliseconds.

The modified code for policy gradient, the natural policy gradient, and plotting code are all up on my GitHub.

Other notes

  • Had to reduce the \delta to 0.001 (from 0.05, which was used in the Rajeswaran paper) to get the learning to converge; I suspect this is because of the greatly reduced complexity of the state space and control problem.
  • The Rajeswaran paper algorithm does gradient ascent instead of descent, which is why the signs are how they are.

Deep learning for control using augmented Hessian-free optimization

Traditionally, deep learning is applied to feed-forward tasks, like classification, where the output of the network doesn’t affect the input to the network. It is a decidedly harder problem when the output is recurrently connected such that network output affects the input. Because of this, the application of deep learning methods to control was largely unexplored until a few years ago. Recently, however, there’s been a lot of progress and research in this area. In this post I’m going to talk about an implementation of deep learning for control presented by Dr. Ilya Sutskever in his thesis Training Recurrent Neural Networks.

In his thesis, Dr. Sutskever uses augmented Hessian-free (AHF) optimization for learning. There are a bunch of papers and posts that go into the details of AHF; here’s a good one by Andrew Gibiansky up on his blog that I recommend you check out. I’m not going to really talk much here about what AHF is specifically, or how it differs from other methods; if you’re unfamiliar there are lots of places you can read up on it. Quickly, though, AHF is kind of a bag of tricks built around a fast method for estimating the curvature of the loss function with respect to the weights of a neural network, as well as the gradient, which allows it to make larger updates in directions where the loss function doesn’t change quickly. So rather than estimating the gradient and then doing a small update along each dimension, you can make the size of your update large in directions that change slowly and small along dimensions where things change quickly. And now that’s enough about that.

In this post I’m going to walk through using a Hessian-free optimization library (version 0.3.8) written by my compadre Dr. Daniel Rasmussen to train up a neural network to control a 2-link arm, and talk about the various hellish gauntlets you need to run to get something that works. Whooo! The first thing to do is install this Hessian-free library, linked above.

I’ll be working through code edited a bit for readability, to find the code in full you can check out the files up on my GitHub.

Build the network

Dr. Sutskever specified the structure of the network in his thesis to be 4 layers: 1) a linear input layer, 2) 100 Tanh nodes, 3) 100 Tanh nodes, 4) linear output layer. The network is connected up with the standard feedforward connections from 1 to 2 to 3 to 4, plus recurrent connections on 2 and 3 to themselves, plus a ‘skip’ connection from layer 1 to layer 3. Finally, the input to the network is the target state for the plant and the current state of the plant. So, lots of recursion! Here’s a picture:

network structure

The output layer connects in to the plant, and, for those unfamiliar with control theory terminology, ‘plant’ just means the system that you’re controlling. In this case an arm simulation.

Before we can go ahead and set up the network that we want to train, we also need to specify the loss function that we’re going to be using during training. The loss function in Ilya’s thesis is a standard one:

L(\theta) = \sum\limits^{N-1}\limits_{t=0} \ell(\textbf{u}_t) + \ell_f(\textbf{x}_N),
\ell(\textbf{u}_t) = \alpha \frac{||\textbf{u}_t||^2}{2},
\ell_f(\textbf{x}_N) = \frac{||\textbf{x}^* - \textbf{x}_N||^2}{2}

where L(\theta) is the total cost of the trajectory generated with \theta, the set of network parameters, \ell(\textbf{u}) is the immediate state cost, \ell_f(\textbf{x}) is the final state cost, \textbf{x} is the state of the arm, \textbf{x}^* is the target state of the arm, \textbf{u} is the control signal (torques) that drives the arm, and \alpha is a gain value.

To code this up using the hessianfree library we do:

from hessianfree import RNNet
from hessianfree.nonlinearities import (Tanh, Linear, Plant)
from hessianfree.loss_funcs import SquaredError, SparseL2

l2gain = 10e-3 * dt # gain on control signal loss
rnn = RNNet(
    # specify the number of nodes in each layer
    shape=[num_states * 2, 96, 96, num_states, num_states],
    # specify the function of the nodes in each layer
    layers=[Linear(), Tanh(), Tanh(), Linear(), plant],
    # specify the layers that have recurrent connections
    rec_layers=[1, 2],
    # specify the connections between layers
    conns={0: [1, 2], 1: [2], 2: [3], 3: [4]},
    # specify the loss function
    loss_type=[
        # squared error between plant output and targets
        SquaredError(),
        # penalize magnitude of control signal (output of layer 3)
        SparseL2(l2gain, layers=[3])],
)

Note that if you want to run it on your GPU you’ll need PyCuda and sklearn installed. And a GPU.

An important thing to note as well is that in Dr. Sutskever’s thesis, when we’re calculating the squared error of the distance from the arm state to the target, this is measured in joint angles. So it’s kind of a weird set up to be looking at the movement of the hand but have your cost function in joint-space instead of end-effector space, but it definitely simplifies training by making the cost more directly relatable to the control signal. So we need to calculate the joint angles of the arm that will have the hand at different targets around a circle. To do this we’ll take advantage of our inverse kinematics solver from way back when, and use the following code:

def gen_targets(arm, n_targets=8, sig_len=100):
    #Generate target angles corresponding to target
    #(x,y) coordinates around a circle
    import scipy.optimize

    x_bias = 0
    if arm.DOF == 2:
        y_bias = .35
        dist = .075
    elif arm.DOF == 3:
        y_bias = .5
        dist = .2

    # set up the reaching trajectories around circle
    targets_x = [dist * np.cos(theta) + x_bias \
                    for theta in np.linspace(0, np.pi*2, 65)][:-1]
    targets_y = [dist * np.sin(theta) + y_bias \
                    for theta in np.linspace(0, np.pi*2, 65)][:-1]

    joint_targets = []
    for ii in range(len(targets_x)):
        # solve for the joint angles that put the hand at the
        # (x, y) target; inv_kinematics here stands in for
        # whatever IK solver your arm class provides
        joint_targets.append(
            arm.inv_kinematics(xy=[targets_x[ii], targets_y[ii]]))
    targs = np.asarray(joint_targets)

    # repeat the targets over time
    for ii in range(targs.shape[1]-1):
        targets = np.concatenate(
            (np.outer(targs[:, ii], np.ones(sig_len))[:, :, None],
             np.outer(targs[:, ii+1], np.ones(sig_len))[:, :, None]), axis=-1)
    targets = np.concatenate((targets, np.zeros(targets.shape)), axis=-1)
    # only want to penalize the system for not being at the
    # target at the final state, set everything before to np.nan
    targets[:, :-1] = np.nan 

    return targets

And you can see in the last couple lines that to implement the distance to target as a final state cost penalty only we just set all of the targets before the final time step equal to np.nan. If we wanted to penalize distance to target throughout the whole trajectory we would just comment that line out.

Create the plant

You’ll notice in the code that defines our RNN that I set the last layer of the network to be plant, but that’s not defined anywhere. Let’s talk. There are a couple of things we’re going to need to incorporate our plant into this network and be able to use any deep learning method to train it. We need to be able to:

  1. Simulate the plant forward; i.e. pass in input and get back the resulting plant state at the next timestep.
  2. Calculate the derivative of the plant state with respect to the input; i.e. how do small changes in the input affect the state.
  3. Calculate the derivative of the plant state with respect to the previous state; i.e. how do small changes in the plant state affect the state at the next timestep.
  4. Calculate the derivative of the plant output with respect to its state; i.e. how do small changes in the current position of the state affect the output of the plant.

So 1 is easy: we have the arm simulations that we want already; they’re up on my GitHub. Number 4 is actually trivial too: because the output of our plant is going to be the state itself, the derivative of the output with respect to the state is just the identity matrix.
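
For completeness, here’s what number 4 looks like in the same shape conventions as the finite differences code below (a sketch; the trailing axis just matches the format used for the other derivatives):

# the plant output is the state itself, so d(output)/d(state)
# is the identity matrix, repeated once per trial
d_output_FD = np.tile(
    np.eye(self.state.shape[1])[None, :, :],
    (x.shape[0], 1, 1))[..., None]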

For 2 and 3, we’re going to need to calculate some derivatives. If you’ve read the last few posts you’ll note that I’m on a finite differences kick. So let’s get that going! Because no one wants to calculate derivatives!

Important note: the notation in these next couple pieces of code is going to be a bit different from my normal notation, because it matches the hessianfree library notation, which comes from a reinforcement learning literature background instead of a control theory background. So, s is the state of the plant, and x is the input to the plant. I know, I know. All the same, make sure to keep that in mind.

# calculate ds0/dx0 with finite differences
d_input_FD = np.zeros((x.shape[0], # number of trials
                       self.state.shape[1], # number of states
                       x.shape[1])) # number of inputs
for ii in range(x.shape[1]):
    # calculate state adding eps to x[ii]
    self.state = np.copy(self.prev_state)
    inc_x = x.copy()
    inc_x[:, ii] += self.eps
    self.activation(inc_x) # simulate the plant forward a step
    state_inc = self.state.copy()
    # calculate state subtracting eps from x[ii]
    self.state = np.copy(self.prev_state)
    dec_x = x.copy()
    dec_x[:, ii] -= self.eps
    self.activation(dec_x)
    state_dec = self.state.copy()

    d_input_FD[:, :, ii] = \
        (state_inc - state_dec) / (2 * self.eps)
d_input_FD = d_input_FD[..., None]

Alrighty. First we create a tensor to store the results. Why is it a tensor? Because we’re going to be doing a bunch of runs at once, so our input dimensions are actually trials x size_input. When we then take the partial derivative, we end up with trials many size_state x size_input matrices. Then we increase each of the parameters of the input slightly one at a time, simulate the plant forward and store the results, decrease them one at a time and store the results, and compute our approximation of the gradient.

Next we’ll do the same for calculating the derivative of the state with respect to the previous state.

# calculate ds1/ds0
d_state_FD = np.zeros((x.shape[0], # number of trials
                       self.state.shape[1], # number of states
                       self.state.shape[1])) # number of states
for ii in range(self.state.shape[1]):
    # calculate state adding eps to self.state[ii]
    state = np.copy(self.prev_state)
    state[:, ii] += self.eps
    self.state = state
    self.activation(x) # simulate the plant forward a step
    state_inc = self.state.copy()
    # calculate state subtracting eps from self.state[ii]
    state = np.copy(self.prev_state)
    state[:, ii] -= self.eps
    self.state = state
    self.activation(x)
    state_dec = self.state.copy()

    d_state_FD[:, :, ii] = \
        (state_inc - state_dec) / (2 * self.eps)
d_state_FD = d_state_FD[..., None]

Great! We’re getting closer to having everything we need. Another thing we need is a wrapper for running our arm simulation. It’s going to look like this:

def activation(self, x):
    state = []
    # iterate through and simulate the plant forward
    # for each trial
    for ii in range(x.shape[0]):
        self.arm.reset(q=self.state[ii, :self.arm.DOF],
                       dq=self.state[ii, self.arm.DOF:])
        # step the arm simulation forward under the input torques;
        # apply_torque here stands in for whatever step function
        # your arm simulation provides
        self.arm.apply_torque(u=x[ii], dt=self.arm.dt)
        state.append(np.hstack([self.arm.q, self.arm.dq]))
    state = np.asarray(state)

    self.state = self.squashing(state)

This is definitely not the fastest code to run. Ideally we would put the states and inputs into vectors and do a single set of computations for each call to activation, rather than having that for loop in there. Unfortunately, though, we’re not assuming that we have access to the dynamics equations / will be able to pass in vector states and inputs.

Looking at the above code that seems pretty clear what’s going on, except you might notice that last line calling self.squashing. What’s going on there?

The squashing function looks like this:

def squashing(self, x):
    index_below = np.where(x < -2*np.pi)
    x[index_below] = np.tanh(x[index_below]+2*np.pi) - 2*np.pi
    index_above = np.where(x > 2*np.pi)
    x[index_above] = np.tanh(x[index_above]-2*np.pi) + 2*np.pi
    return x

All that’s happening here is that we’re taking our input, and doing nothing to it as long as it doesn’t start to get too positive or too negative. If it does then we just taper it off and prevent it from going off to infinity. So running a 1D vector through this function we get:
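
If you want to generate that plot yourself, here’s a standalone version of the function (just dropping the self argument):

import numpy as np
import matplotlib.pyplot as plt

def squashing(x):
    x = np.copy(x)
    index_below = np.where(x < -2*np.pi)
    x[index_below] = np.tanh(x[index_below] + 2*np.pi) - 2*np.pi
    index_above = np.where(x > 2*np.pi)
    x[index_above] = np.tanh(x[index_above] - 2*np.pi) + 2*np.pi
    return x

x = np.linspace(-3*np.pi, 3*np.pi, 1000)
plt.plot(x, squashing(x))
plt.xlabel('input')
plt.ylabel('squashed output')
plt.show()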

This ends up being a pretty important piece of the code here. Basically it prevents wild changes to the weights during learning from breaking the system down. The state of the plant can’t go off to infinity and cause an error to be thrown, stopping our simulation. But because the target state is well within the bounds where the squashing function does nothing, post-training we’ll still be able to use the resulting network to control a system that doesn’t have this fail safe built in. Think of this function as training wheels that catch you only if you start falling.

With that, we now have pretty much all of the parts necessary to begin training our network!

Training the network

We’re going to be training this network on the centre-out reaching task, where you start at a centre point and reach to a bunch of target locations around a circle. I’m just going to be re-implementing the task as it was done in Dr. Sutskever’s thesis, so we’ll have 64 targets around the circle, and train using a 2-link arm. Here’s the code that we’ll use to actually run the training:

for ii in range(last_trial+1, num_batches):
    # run the network and update the weights, generating
    # fresh rollouts from the plant every batch
    err = rnn.run_batches(plant, targets=None,
              optimizer=HessianFree(CG_iter=96, init_damping=100))

    # save the weights to file, track trial and error
    # err = rnn.error(inputs)
    err = rnn.best_error
    name = 'weights/rnn_weights-trial%04i-err%.3f'%(ii, err)
    np.savez_compressed(name, rnn.W)

Training your own network

A quick aside: if you want to run this code yourself, get a real good computer, have an arm simulation ready, have the hessianfree Python library installed (note: I used version 0.3.8, which you can install using pip install hessianfree==0.3.8), and download and run the training script from my GitHub. This will start training and save the weights into a weights/ folder, so make sure that folder exists alongside the training script. If you want to view the results of the training at any point, run the plotting script, which will load in the most recent version of the weights and plot the error so far. If you want to generate an animated plot like I have below, run the animation script and then the last command from my post on generating animated gifs.

Another means of seeing the results of your trained up network is to use the controller I’ve implemented in my controls benchmarking suite, which looks for a set of saved weights in the controllers/weights folder, and will load it in and use it to generate command signals for the arm by running it with

python arm2_python ahf reach --dt=1e-2

where you replace arm2_python with whatever arm model you trained your model on. Note the --dt=1e-2 flag; that is important because the model was trained with a .01 timestep, and things get a bit weird if you suddenly change the dynamics on the controller.

OK let’s look at some results!


OK, full disclosure: these results are not optimizing the cost function we discussed above. They’re implementing a simpler cost function that only looks at the final state, i.e. it doesn’t penalize the magnitude of the control signal. I did this because Dr. Sutskever says in his thesis he was able to optimize with just the final state cost using much smaller networks. I originally looked at networks with 96 neurons in each layer, and it just took forgoddamnedever to run. So after running for 4 weeks (not joking) and needing to make some more changes I dropped the number of neurons and simplified the task.

The results below are from running a network with 32 neurons in each layer controlling this 2-link arm, and took another 4-5 weeks to train up.

Hey that looks good! Not bad, augmented Hessian-free learning, not bad. It had a pretty consistent (if slow) decline in the error rate, with a few crazy bumps from which it quickly recovered. Also take note that each training iteration is actually 32 runs, so it’s not 12,500-ish runs, it’s closer to 400,000 training runs that it took to get here.

One biggish thing that was a pain was that it turns out I only trained the neural network to reach in one direction, and when you only train it to reach one way it doesn’t generalize to reaching back to the starting point (which, fair enough). But I didn’t realize this until I took the trained network and ran it in the benchmarking code, at which point I was not keen to redo all of the training it took to get the neural network to its current level of accuracy under a more complicated training set. The downside of this is that even though I’ve implemented a controller that takes in the trained network and uses it to control the arm, to do the reaching task I have to do a hard reset after the arm reaches the target, because, unlike all the other controllers, it can’t reach back to the center. All the same, here’s an animation of the trained up AHF controller reaching to 8 targets (it was trained on all 64 above, though):


Things don’t always go so smoothly, though. Here’s results from another training run that took around 2-3 weeks, and uses a different 2-link arm model (translated from Matlab code written by Dr. Emo Todorov):


What I found frustrating about this was that if you look at the error over time, this arm is doing as well or better than the previous arm at a lot of points. But the corresponding trajectories look terrible, like something you would see in a horror movie based around getting good machine learning results. This of course comes down to how I specified the cost function: when I looked at the trajectories plotted over time, the velocity of the arm is right at zero at the final time step, which is not quiiiitte the case for the first controller. So this second network has found a workaround to minimize the cost function I specified in a way I did not intend. To prevent this, doing something like weighting the distance to target heavier than non-zero velocity would probably work. Or possibly, just by rerunning the training with a different random starting point, you could get out a better controller; I don’t have a great feel for how important the random initialization is, but I’m hoping that it’s not all too important and its effects go to zero with enough training. Also, it should be noted I’ve run the first network for 12,500 iterations and the second for less than 6,000, so I’ll keep letting them run and maybe it will come around. The first one looked pretty messy too until about 4,000 iterations in.

Training regimes

Frustratingly, the way that you train deep networks is very important. So, very much like the naive deep learning network trainer that I am, I tried the first thing that pretty much anyone would try:

  • run the network,
  • update the weights,
  • repeat.

This is what I’ve done in the results above. And it worked well enough in that case.

If you remember back to the iLQR I made a little while ago, I was able to change the cost function to be

L(\theta) = \sum\limits^{N-1}\limits_{t=0} \ell(\textbf{u}_t, \textbf{x}_t) + \ell_f(\textbf{x}_N),
\ell(\textbf{u}_t, \textbf{x}_t) = \alpha \frac{||\textbf{u}_t||^2}{2} + \frac{||\textbf{x}^* - \textbf{x}_t||^2}{2},
\ell_f(\textbf{x}_N) = \frac{||\textbf{x}^* - \textbf{x}_N||^2}{2}

(i.e. to include a penalty for distance to target throughout the trajectory and not just at the final time step) which resulted in straighter trajectories when controlling the 2-link arm. So I thought I would try this here as well. Sadly (incredibly sadly), this was fairly fruitless. The network didn’t really learn or improve much at all.

After much consideration and quandary on my part, I talked with Dr. Dan and he suggested that I try another method:

  • run the network,
  • record the input,
  • hold the input constant for a few batches of weight updating,
  • repeat.

This method gave much better results. BUT WHY? I hear you ask! Good question. Let me give explaining it a go.

Essentially, it’s because the cost function is more complex now. In the first training method, the output from the plant is fed back into the network as input at every time step. When the cost function was simpler this was OK, but now we’re getting very different input to train on at every iteration. So the system is being pulled in different directions back and forth at every iteration. In the second training regime, the same input is given several times in a row, which lets the system follow the same gradient for a few training iterations before things change again. In my head I picture this as giving the algorithm a couple of seconds to catch its breath before dunking it back underwater.

This is a method that’s been used in a bunch of places recently. One of the more high-profile instances is in the results published from DeepMind on deep RL for control and for playing Go. And indeed, it also works well here.

To implement this training regime, we set up the following code:

for ii in range(last_trial+1, num_batches):

    # run the plant forward once
    rnn.forward(input=plant, params=rnn.W)

    # get the input and targets from above rollout
    inputs = plant.get_vecs()[0].astype(np.float32)
    targets = np.asarray(plant.get_vecs()[1], dtype=np.float32)

    # train a bunch of batches using the same input every time
    # to allow the network a chance to minimize things with
    # stable input (speeds up training)
    err = rnn.run_batches(inputs, targets, max_epochs=batch_size,
              optimizer=HessianFree(CG_iter=96, init_damping=100))

    # save the weights to file, track trial and error
    # err = rnn.error(inputs)
    err = rnn.best_error
    name = 'weights/rnn_weights-trial%04i-err%.3f'%(ii, err)
    np.savez_compressed(name, rnn.W)

So you can see that we do one rollout with the weights, then go in and get the inputs and targets that were used in that rollout, and start training the network while holding those constant for batch_size epochs (training sessions). From a little bit of messing around I’ve found batch_size=32 to be a pretty good number. So then it runs 32 training iterations where it’s updating the weights, and then saves those weights (because we want a loooottttt of check-points) and then restarts the loop.

Embarrassingly, I’ve lost my simulation results from this trial, somehow… so I don’t have any nice plots to back up the above, unfortunately. But since this is just a blog post I figured I would at least talk about it a little bit, since people might still find it useful if they’re just getting into the field like me, and I’ll just update this post whenever I re-run them. If I rerun them.

What I do have, however, are results where this method doesn’t work! I tried this with the simpler cost function, that only looks at the final state distance from the target, and it did not go so well. Let’s look at that one!


My guess here is basically that the system has gotten to a point where it’s narrowed things down in the parameter space and now when you run 32 batches it’s overshooting. It needs feedback about its updates after every update at this point. That’s my guess, at least. So it could be the case that for more complex cost functions you’d want to train it while holding the input constant for a while, and then when the error starts to plateau switch to updating the input after every parameter update.


All in all, AHF for training neural networks in control is pretty awesome. There are of course still some major hold-backs, mostly related to how long it takes to train up a network, and having to guess at effective training regimes and network structures etc. But! It was able to train up a relatively small neural network to move an arm model from a center point to 64 targets around a circle, with no knowledge of the system under control at all. In Dr. Sutskever’s thesis he goes on to use the same set up under more complicated circumstances, such as when there’s a feedback delay, or a delay on the outgoing control signal, and unexpected noise etc, so it is able to learn under a number of different, fairly complex situations. Which is pretty slick.

Related to the insane training time required, I very easily could be missing some basic thing that would help speed things up. If you, reader, get ambitious and run the code on your own machine and find out useful methods for speeding up the training please let me know! Personally, my plan is to next investigate guided policy search, which seems like it’s found a way around this crazy training time.
