Predicting Time Series and Completing Drawings with Transformers
Predicting time series data can be hard. Predicting multivariate time series data is definitely hard. Predicting multivariate time series data where different variables are different types of data presents a unique and interesting challenge, one that a partner and I recently took on, with great results.
Part of what makes multivariate time series hard is that there’s no definitive best way to approach it. For a task like image classification, for instance, it’s generally agreed that some form of convolutional neural network is the best tool for the job. For time series there are various autoregressive models (ARIMA, SARIMA, etc.) and neural and convolution-based approaches (RNN, LSTM, ROCKET, etc.), but each method has its pros and cons, so no clear optimal strategy exists yet. One of the most promising models for time series prediction is the transformer. Transformers were originally designed for natural language processing (NLP), and since their introduction they have largely taken over the field, producing some remarkable large-scale models (e.g., BERT, GPT-2/3). With some work, the same architecture can be applied to time series.
When people picture time series prediction, they usually think of stock prediction, weather forecasting, logistics analysis, and so on. You would be forgiven for thinking a drawing of a cat should be treated as an image rather than a time series. However, a drawing is made by a human moving a pen across a page in a particular sequence. The actual timestamps matter less than the fact that the drawing is an ordered set of points. To see why, we need to take a closer look at the data we’re working with.
Data
Several years ago Google put out a fun little game called Quick, Draw! The user draws a simple image and a neural network on the backend tries to guess what it is. The drawings then help the AI make better guesses for later users. It’s a fun way to spend a couple of minutes. But it got even more interesting after it had been out for a while: the full datasets of drawings were released, and Google also released what they called Sketch-RNN, a recurrent neural network architecture that can analyze drawings to generate new ones, and even predict how a partial drawing could be completed. It had very good results and uses an interesting autoencoder architecture, but as anyone familiar with RNNs knows, they have real limitations, such as processing a sequence one step at a time and struggling with long-range dependencies. Transformers don’t have those problems. So let’s take a look at the data and see whether we can use transformers (if you’re reading this, you already know we can, and did).
The datasets are divided by class, matching the categories used by Google’s classifier backend: for instance cat, bus, wine bottle, or the Mona Lisa. Each of these sets has about 70,000 unique drawings made by Quick, Draw! users. Rather than a traditional image file, each drawing is stored as an ordered list of the strokes used to make it, and each stroke is made up of the coordinate points along the path of the pen. For modeling, it’s better to use offsets from the previous point than absolute coordinates, and instead of separating strokes into new list entries, the Sketch-RNN authors define a variable for each offset point called the “pen state”. The pen state is a set of three binary variables (really one three-valued variable, since exactly one of them is 1 at any point): [1, 0, 0] if the pen is down and drawing; [0, 1, 0] if the pen is lifted, i.e. the current stroke has ended and the next offset moves to where the next stroke begins; and [0, 0, 1] when the whole drawing is done and no more points follow. In total, each data point has the form [x, y, p1, p2, p3], where x and y are the offsets from the previous point in each dimension and (p1, p2, p3) is the pen state at that point.
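The conversion described above can be sketched in Python with NumPy. This is a minimal sketch under stated assumptions: each stroke arrives as a pair of x and y coordinate lists (the layout of the simplified Quick, Draw! files), the first point is measured from (0, 0), and pen states follow the Sketch-RNN convention of setting p2 on the last point of each stroke and appending a final [0, 0, 0, 0, 1] row. The function name `strokes_to_stroke5` is hypothetical.

```python
import numpy as np


def strokes_to_stroke5(drawing):
    """Convert a list of strokes into stroke-5 rows [dx, dy, p1, p2, p3].

    Assumes each stroke is a pair (xs, ys) of equal-length coordinate lists.
    p2 marks the last point of a stroke (the pen lifts after it), and a final
    [0, 0, 0, 0, 1] row marks the end of the drawing.
    """
    points, pen = [], []
    for xs, ys in drawing:
        for i, (x, y) in enumerate(zip(xs, ys)):
            points.append((x, y))
            last_in_stroke = i == len(xs) - 1
            pen.append((0, 1, 0) if last_in_stroke else (1, 0, 0))
    abs_xy = np.array(points, dtype=np.float32)
    # Offsets from the previous point; the first point is measured from (0, 0).
    deltas = np.diff(abs_xy, axis=0, prepend=np.zeros((1, 2), dtype=np.float32))
    rows = np.hstack([deltas, np.array(pen, dtype=np.float32)])
    end = np.array([[0, 0, 0, 0, 1]], dtype=np.float32)
    return np.vstack([rows, end])
```

Running it on two short strokes, `[([0, 10, 20], [0, 0, 5]), ([30, 30], [10, 20])]`, yields six rows, the third of which is `[10, 5, 0, 1, 0]`: the pen lifts after that point and the next offset jumps to the start of the second stroke.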
As an example, the first three points of a cat drawing in the set were: [[-18. 5. 1. 0. 0.], [-15. 16. 1. 0. 0.], [ -7. 13. 1. 0. 0.]…] The drawing moves 18 pixels left and 5 up, then 15 left and 16 up, followed by 7 left and 13 up. The pen is down the whole time so the drawing is connected between each of those points.
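To check the format against the example above, the offsets can be turned back into absolute positions with a running sum. The sketch below (a hypothetical helper, `stroke5_to_polylines`, using NumPy) also splits the points into separate pen strokes wherever p2 is set and stops at the first p3 row, which is roughly what is needed before plotting a drawing.

```python
import numpy as np


def stroke5_to_polylines(s5, start=(0.0, 0.0)):
    """Turn stroke-5 rows back into a list of absolute-coordinate polylines.

    Offsets are accumulated with a running sum; a new polyline starts after
    any row whose p2 flag is set, and decoding stops at the first p3 row.
    """
    s5 = np.asarray(s5, dtype=np.float32)
    xy = np.cumsum(s5[:, :2], axis=0) + np.asarray(start, dtype=np.float32)
    lines, current = [], []
    for (x, y), (p1, p2, p3) in zip(xy, s5[:, 2:]):
        if p3:  # end of drawing
            break
        current.append((float(x), float(y)))
        if p2:  # pen lifts after this point: close the current stroke
            lines.append(current)
            current = []
    if current:
        lines.append(current)
    return lines
```

For the three cat points above it returns a single polyline, `[(-18.0, 5.0), (-33.0, 21.0), (-40.0, 34.0)]`, relative to the starting position.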
This means each drawing is an ordered list of 5-dimensional points. Because the points are ordered the same way the user drew them, a drawing can be treated as a sequence in 5 dimensions, and if we can account for the higher dimensionality, we can use transformers to predict these sequences and thereby predict the rest of the drawing. Getting it to work with transformers is the hard part.
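Treating a drawing as a sequence means training pairs can be built the same way as for next-word prediction: the target at each step is the input shifted by one point. The sketch below assumes fixed-length batches, padding short drawings with end-of-drawing rows (a convention borrowed from Sketch-RNN) and truncating long ones; `make_training_pair` and `max_len` are hypothetical names.

```python
import numpy as np

END_ROW = np.array([0, 0, 0, 0, 1], dtype=np.float32)


def make_training_pair(s5, max_len):
    """Pad a stroke-5 sequence and split it into (input, target) arrays.

    Sequences shorter than max_len are padded with end-of-drawing rows;
    longer ones are truncated. The target at step t is the input at t + 1.
    """
    s5 = np.asarray(s5, dtype=np.float32)[:max_len]
    pad = np.tile(END_ROW, (max_len - len(s5), 1))
    seq = np.vstack([s5, pad])
    return seq[:-1], seq[1:]
```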
Transformers and our Architecture
I won’t go too deep into transformers themselves, because there are many great resources if you’re totally new to them. The core idea is to use self-attention blocks to learn the relationships between different points in a sequence. Transformers also rely on a positional encoding so that the model knows where in the sequence each point sits, since attention on its own ignores order. A standard transformer takes two inputs: the input sequence itself (for a translation model, an English sentence) and the partial output sequence produced so far (however much of the sentence has already been translated into the other language). In other words, transformers exploit the fact that knowing part of the output helps predict the rest. For instance, it’s easier to translate an English sentence into Spanish if you already know the first word of the Spanish translation.
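The positional encoding isn’t specified here; one common choice, shown below as an assumption, is the sinusoidal encoding from the original transformer paper, written in NumPy. It produces one `d_model`-length vector per position, which is added to the projected input points so attention can tell positions apart.

```python
import numpy as np


def positional_encoding(length, d_model):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    PE[pos, 2i]     = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))
    Assumes d_model is even.
    """
    pos = np.arange(length)[:, None]        # (length, 1)
    two_i = np.arange(0, d_model, 2)[None, :]  # (1, d_model / 2), values 0, 2, 4, ...
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe
```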
For our data, the input and output take the same form: we want to predict later points in the drawing, not translate those points into another form. Since the input and output are the same, we don’t need both the encoder and decoder halves of the transformer. We can use only the decoder, because that is where the masked self-attention block lives, and masking is what lets the model learn to predict the next points without seeing them (see the diagram above for encoder vs. decoder). A standard decoder has two stacked sets of attention blocks: the first/lower set for masked self-attention, and the second/higher set for attention between the masked-attention output and the output of the encoder. Since we have no encoder, we don’t need the second/higher set. This decoder-only architecture is similar to the approach used in the well-known GPT-2/3 models.
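The masking that makes a decoder-only model work can be shown with single-head scaled dot-product attention in NumPy. This is a simplified sketch (one head, no layer norm or feed-forward sublayer, weights passed in directly), not a full decoder block: every position t gets attention weight zero on positions after t, so the output at t depends only on the points drawn so far.

```python
import numpy as np


def causal_self_attention(x, wq, wk, wv):
    """Single-head masked self-attention over a (seq_len, d_model) array.

    Returns the attended outputs and the (seq_len, seq_len) weight matrix.
    Position t may only attend to positions <= t.
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    seq_len = x.shape[0]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)  # block attention to later points
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```

In Keras, `tf.keras.layers.MultiHeadAttention` provides the full multi-head version and accepts an `attention_mask` in which True marks positions that may be attended to, so the lower-triangular complement of `future` is what would be passed there.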
Transformers were originally designed for NLP, where each position in the input is a single token (a word or word piece) identified by an integer ID. An embedding layer maps that ID to a learned vector, which is equivalent to multiplying a one-hot vector by a weight matrix. Our sketch points are different: each position carries five values rather than one discrete token. So the next change we made was to replace the embedding layer with a linear projection layer, which maps each 5-dimensional point into the space the rest of the model expects.
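The difference between the two input layers can be made concrete in NumPy. An embedding lookup picks one learned row per integer token, which equals multiplying one-hot vectors by the embedding table; the projection used for sketches instead multiplies each real-valued 5-dimensional point by a learned (5, d_model) matrix and adds a bias. The function names and the `d_model = 16` size are illustrative assumptions.

```python
import numpy as np


def embed_tokens(token_ids, table):
    """NLP-style input: look up one learned row per integer token ID."""
    return table[token_ids]


def project_points(points, w, b):
    """Sketch-style input: a learned linear map from each 5-d point to d_model.

    points: (seq_len, 5) stroke-5 rows, w: (5, d_model), b: (d_model,).
    """
    return np.asarray(points, dtype=np.float64) @ w + b


rng = np.random.default_rng(0)
d_model, vocab_size = 16, 100

# An embedding lookup equals multiplying one-hot vectors by the table.
table = rng.normal(size=(vocab_size, d_model))
ids = np.array([4, 17, 4])
print(np.allclose(embed_tokens(ids, table), np.eye(vocab_size)[ids] @ table))  # True

# The projection accepts real-valued offsets and pen states directly.
w, b = rng.normal(size=(5, d_model)), np.zeros(d_model)
cat = [[-18, 5, 1, 0, 0], [-15, 16, 1, 0, 0], [-7, 13, 1, 0, 0]]
print(project_points(cat, w, b).shape)  # (3, 16)
```

In Keras, a `Dense` layer with no activation plays the role of `project_points`.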
That covers the input, but we also need to change the output of the model. The original transformer ends in a softmax layer because in NLP it predicts the next word in a sequence. Our output is a mixture of offsets and pen states. The natural approach is a dual output: a linear activation for the offsets and a softmax activation for the pen state, with the two outputs concatenated. With just those changes you can train a model, but the results leave a lot to be desired.
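A minimal NumPy sketch of the dual output head described above follows. The weight shapes and the helper names `softmax` and `dual_output_head` are assumptions for illustration; in Keras the same idea would be two `Dense` layers (one linear, one with a softmax activation) whose outputs are concatenated.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def dual_output_head(h, w_off, b_off, w_pen, b_pen):
    """Map decoder features h (seq_len, d_model) to stroke-5 style predictions.

    The two offsets use a linear activation (unbounded real values); the three
    pen-state logits go through a softmax so each row is a probability
    distribution. The parts are concatenated into a (seq_len, 5) array.
    """
    offsets = h @ w_off + b_off             # (seq_len, 2)
    pen_probs = softmax(h @ w_pen + b_pen)  # (seq_len, 3), rows sum to 1
    return np.concatenate([offsets, pen_probs], axis=-1)


rng = np.random.default_rng(1)
d_model = 16
h = rng.normal(size=(4, d_model))  # stand-in for the last decoder block's output
out = dual_output_head(
    h,
    rng.normal(size=(d_model, 2)), np.zeros(2),
    rng.normal(size=(d_model, 3)), np.zeros(3),
)
print(out.shape)  # (4, 5)
```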
The Problem
If we do just that, we need a custom loss function in which the offsets use mean squared error and the pen state uses cross-entropy. This is a logical choice, because the offset is a distance measure and the pen state is effectively a class label. Writing custom loss functions in TensorFlow/Keras isn’t hard, and combining different types of loss this way is relatively simple. The problem arises during training and use. When we trained the model with this combined loss, it struggled to learn pen state and offset at the same time. We could get decent offsets; for instance, it eventually learned the basic shape of a drawing like a cat, but it couldn’t handle the pen state and would usually never lift the pen. If we adjusted the loss function and/or hyperparameters so that it learned the pen state, it could change the pen state (i.e., lift the pen), but then it couldn’t learn the shapes in the drawing. To solve this, we implemented a branched model.
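The combined loss can be written out in NumPy to show how the two terms sit side by side. This is a sketch assuming stroke-5 targets with one-hot pen states and predictions laid out like the dual head (two offsets, then three softmax probabilities); `pen_weight` is a hypothetical knob of the kind one might tune to balance the two terms, and a real TensorFlow/Keras version would use TensorFlow ops instead of NumPy.

```python
import numpy as np


def combined_loss(y_true, y_pred, pen_weight=1.0, eps=1e-7):
    """MSE on the (dx, dy) offsets plus cross-entropy on the pen state.

    y_true: (seq_len, 5) stroke-5 rows with one-hot pen states.
    y_pred: (seq_len, 5) predicted offsets followed by softmax pen probabilities.
    pen_weight: hypothetical weighting between the two terms.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    mse = np.mean((y_true[:, :2] - y_pred[:, :2]) ** 2)
    pen_probs = np.clip(y_pred[:, 2:], eps, 1.0)  # avoid log(0)
    xent = -np.mean(np.sum(y_true[:, 2:] * np.log(pen_probs), axis=-1))
    return mse + pen_weight * xent


y_true = [[0, 0, 1, 0, 0], [2, 0, 0, 1, 0]]
y_pred = [[1, 1, 0.5, 0.5, 0.0], [2, 0, 0.25, 0.5, 0.25]]
print(round(combined_loss(y_true, y_pred), 4))  # 1.1931 (MSE 0.5 + cross-entropy ln 2)
```

Because both terms flow into the same weights, raising `pen_weight` tends to help the pen state at the expense of the offsets, and lowering it does the reverse.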
In this context, a branched model means that we split the model partway through and gave it two separate loss functions instead of one combined loss. Our original model used 6 stacked decoder blocks. For the branched version, we use 4 stacked decoder blocks, after which the model “branches” into two parts. Each branch takes the same input, the output of the 4 shared decoder blocks, and adds 2 more decoder blocks followed by its own output layer: one branch for offsets and one for pen state. That makes 8 decoder blocks in total (4 shared plus 2 in each branch), only 2 more than before, but now some decoder blocks are dedicated to learning just the pen state, some to just the offsets, and some to both. With a branched model like this, the loss function is really two loss functions. Because we are minimizing two separate losses, it’s possible that reducing one could substantially raise the other, but since pen state and offset values should be correlated, in general reducing one also tends to reduce the other.
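The branching structure can be sketched as follows. This NumPy version is purely structural: `make_decoder_block` is a hypothetical, heavily simplified stand-in for a real decoder block (single-head causal attention and a ReLU feed-forward layer with residual connections, random untrained weights), and there is no training loop. What it does show is the wiring: 4 shared blocks, then two branches of 2 blocks each, one ending in a linear offset output and one in a softmax pen-state output, for 8 blocks in total.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_decoder_block(d_model):
    """Hypothetical, simplified decoder block: single-head causal self-attention
    and a ReLU feed-forward layer, each with a residual connection (no layer
    norm, random untrained weights)."""
    scale = d_model ** -0.5
    wq, wk, wv, w_ff = (rng.normal(scale=scale, size=(d_model, d_model)) for _ in range(4))

    def block(x):
        n = x.shape[0]
        scores = (x @ wq) @ (x @ wk).T * scale
        future = np.triu(np.ones((n, n), dtype=bool), k=1)  # positions after t
        scores = np.where(future, -np.inf, scores)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        x = x + weights @ (x @ wv)
        return x + np.maximum(x @ w_ff, 0.0)

    return block


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def build_branched_model(d_model, n_shared=4, n_branch=2):
    """Wire shared trunk blocks into two branches with separate output heads."""
    shared = [make_decoder_block(d_model) for _ in range(n_shared)]
    offset_branch = [make_decoder_block(d_model) for _ in range(n_branch)]
    pen_branch = [make_decoder_block(d_model) for _ in range(n_branch)]
    w_off = rng.normal(scale=d_model ** -0.5, size=(d_model, 2))
    w_pen = rng.normal(scale=d_model ** -0.5, size=(d_model, 3))

    def run(blocks, x):
        for block in blocks:
            x = block(x)
        return x

    def model(h):
        trunk = run(shared, h)  # trained by both losses
        offsets = run(offset_branch, trunk) @ w_off  # linear output, offset loss only
        pen = softmax(run(pen_branch, trunk) @ w_pen)  # softmax output, pen loss only
        return offsets, pen

    return model, n_shared + 2 * n_branch


model, n_blocks = build_branched_model(d_model=16)
offsets, pen = model(rng.normal(size=(10, 16)))
print(n_blocks, offsets.shape, pen.shape)  # 8 (10, 2) (10, 3)
```

For training, each branch’s output gets its own loss. In Keras, a functional model with two named outputs can be compiled with a dict of losses, for example mean squared error for the offset output and categorical cross-entropy for the pen output, and Keras adds them up into the total it minimizes; the branch weights then only receive gradients from their own loss.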
Results
In practice, these changes were exactly what was needed to get great results. Dedicating decoder blocks to each part of the output let the model learn the patterns for each variable more effectively, and it could predict both offset and pen state much more reliably. Prediction here means predicting the next point of a drawing, over and over, so if we start with very little (perhaps just a circle, or the first 5–10 points), the model should be able to predict what the rest of the drawing may look like. With just a bit of training the predictions start to resemble cats, and with more training (30 or more epochs) it very reliably draws a recognizable cat from a simple input.
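Completing a drawing means running the model repeatedly, feeding each predicted point back in as input. The loop below is a sketch under assumptions: `predict_next` is a hypothetical callable standing in for the trained model, returning the next point’s offsets and pen-state probabilities; the pen state is sampled from those probabilities (taking the most likely state would be an equally valid choice); and `max_len` is a hypothetical cap. Generation stops at an end-of-drawing state or after `max_len` points.

```python
import numpy as np


def complete_drawing(predict_next, prefix, max_len=250, rng=None):
    """Autoregressively extend a partial stroke-5 drawing.

    predict_next: hypothetical stand-in for the trained model. Given the
    sequence so far, shape (t, 5), it returns (offsets, pen_probs) for the
    next point: a length-2 array and a length-3 probability vector.
    """
    rng = np.random.default_rng() if rng is None else rng
    seq = [np.asarray(row, dtype=np.float32) for row in prefix]
    while len(seq) < max_len:
        offsets, pen_probs = predict_next(np.stack(seq))
        pen = np.zeros(3, dtype=np.float32)
        pen[rng.choice(3, p=pen_probs)] = 1.0  # sample a pen state, then one-hot it
        seq.append(np.concatenate([np.asarray(offsets, dtype=np.float32), pen]))
        if pen[2] == 1.0:  # end-of-drawing state: stop generating
            break
    return np.stack(seq)
```

With a real model, `prefix` would be the first few points of a drawing, and the result can be decoded back into pen strokes for plotting.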
These same techniques could likely be applied to other time series and produce good results, given the right data and hyperparameters. In particular, the branching approach should transfer to a wide variety of datasets, especially those that mix different types of variables. Transformers revolutionized natural language processing because self-attention lets them learn relationships across a whole sequence more quickly and accurately than earlier models. Going forward, I fully expect transformers to take over many fields of machine learning for the same reasons.
Code and links
My website: https://mag389.github.io/
Our code: https://github.com/95ktsmith/Sketch-Transformer
Quick, Draw data: https://quickdraw.withgoogle.com/data
Sketch-RNN paper: https://arxiv.org/pdf/1704.03477.pdf
The original transformer paper: https://arxiv.org/abs/1706.03762