Visualizing Neural Network Performance on High-Dimensional Data
Mon, 13 Mar 2017
Computer Science, Deep Learning, Machine Learning, Matplotlib, Neural Networks, Python, Software, Tensorflow
This post presents a short script that plots neural network performance on high-dimensional binary data using Matplotlib in Python. Binary vectors, which contain only 0s and 1s, are useful for representing categorical data or discrete phenomena. The code in this post is available on GitHub.
Encoded Binary Vectors
The data set is assumed to contain \(m\)-dimensional sample vectors associated with \(n\)-dimensional target vectors. As these vectors contain only 1s and 0s, the sample and target vectors can be considered as bit strings of length \(m\) and \(n\) respectively. These bit strings can then be thought of as unsigned integers. The following code converts a bit vector into its integer representation and back:
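#Converts a binary vector (left to right format) to an integer
#x: A binary vector
#return: The corresponding integer
def BinVecToInt(x):
    #Accumulator variable
    num = 0
    #Place multiplier
    mult = 1
    for i in x:
        #Cast is needed to prevent conversion to floating point
        num = num + int(i) * mult
        #Multiply by 2
        mult <<= 1
    return num

#Converts an integer to a binary vector (left to right format)
#NOTE: the name, signature, and setup of this function are assumed;
#only the loop body is taken from the original listing
#x: A non-negative integer
#v: An optional vector to fill; higher bits of x that do not fit are dropped
#return: The corresponding binary vector
def IntToBinVec(x, v = None):
    #If no vector is given, make one just long enough to hold x
    if v is None:
        v = [0] * max(x.bit_length(), 1)
    #Index of the current column
    c = 0
    while x > 0:
        #Stop once the vector is full
        if c >= len(v):
            break
        #Test if the LSB is set
        if(x & 1 == 1):
            #Set the bits in right-to-left order
            v[c] = 1
        #Onto the next column and bit
        c += 1
        x >>= 1
    return v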
The above code works for binary vectors of arbitrary length, since Python's built-in integer type has arbitrary precision.
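As a quick check of the encoding, the sketch below round-trips a vector in both directions. The helper names `bits_to_int` and `int_to_bits` are hypothetical stand-ins written for this example, not the functions from the post; they assume the same convention in which the leftmost element is the least significant bit.

```python
# Hypothetical helpers for illustration; they assume the leftmost
# element of the vector is the least significant bit.
def bits_to_int(bits):
    # Sum each bit shifted into its place value
    return sum(int(b) << k for k, b in enumerate(bits))

def int_to_bits(x, length):
    # Read off the lowest `length` bits, least significant first
    return [(x >> k) & 1 for k in range(length)]

vec = [1, 0, 1, 1]              # 1*1 + 0*2 + 1*4 + 1*8
print(bits_to_int(vec))         # 13
print(int_to_bits(13, 4))       # [1, 0, 1, 1]

# Python integers are unbounded, so long vectors work as well
big = [0] * 99 + [1]
print(bits_to_int(big) == 2 ** 99)  # True
```

Passing an explicit length to the decoder keeps every decoded vector the same size, which matters when the results are compared element by element.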
Animated Learning
Now the \(m\)-dimensional sample vectors and \(n\)-dimensional target vectors can each be transformed into integers using the above code. The original data set is thereby condensed into two dimensions and can be plotted on the standard Cartesian plane using Matplotlib, with the x-axis corresponding to the sample vectors and the y-axis to the target vectors. The following code uses Matplotlib to produce an animation showing the target data and the model's predictions as training proceeds.
#Import aliases are inferred from the names used below;
#MSE is assumed to be scikit-learn's mean_squared_error
import matplotlib.pyplot as mpl
import matplotlib.animation as mpla
from sklearn.metrics import mean_squared_error as MSE

#Plot the model R learning the data set A, Y
#R: A regression model
#A: The data samples
#Y: The target vectors
def PlotLearn(R, A, Y):
    intA = [BinVecToInt(j) for j in A]
    intY = [BinVecToInt(j) for j in Y]
    fig, ax = mpl.subplots(figsize=(20, 10))
    ax.plot(intA, intY, label ='Orig')
    l, = ax.plot(intA, intY, label ='Pred')
    ax.legend(loc = 'upper left')
    #Updates the plot in ax as model learns data
    def UpdateF(i):
        R.fit(A, Y)
        YH = R.predict(A)
        S = MSE(Y, YH)
        intYH = [BinVecToInt(j) for j in YH]
        l.set_ydata(intYH)
        ax.set_title('Iteration: ' + str(i * 64) + ' - MSE: ' + str(S))
        return l,
    ani = mpla.FuncAnimation(fig, UpdateF, frames = 2000, interval = 128, repeat = False)
    #ani.save('foo.mp4') #ffmpeg is required to save the animation to an mp4
    mpl.show()
    return ani
In the above code, the nested function UpdateF is known as a closure. Since functions are first-class citizens in Python, they can be created as local variables inside another function. This is useful here because UpdateF can reference the Matplotlib objects (l and ax), along with the model and data, from the enclosing scope in order to update the prediction data. Closures are a powerful if often overlooked feature of Python that will be explored in a later post.
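The closure mechanism can be seen in isolation in the minimal sketch below. The names `make_updater`, `update`, and `history` are hypothetical and chosen only for this example: `make_updater` plays the role of the outer plotting function, and `history` stands in for the plot line that the nested function modifies.

```python
# make_updater is a hypothetical stand-in for the outer function;
# history plays the role of the plot line that the closure updates.
def make_updater(history):
    calls = 0

    def update(value):
        # nonlocal is needed to rebind a variable in the enclosing scope
        nonlocal calls
        calls += 1
        # Mutating an enclosing object needs no declaration
        history.append(value)
        return calls

    return update

log = []
update = make_updater(log)
update(0.5)
update(0.25)
print(log)          # [0.5, 0.25]
print(update(0.1))  # 3
```

The nested function keeps access to `calls` and `history` after `make_updater` has returned, which is the same property that lets an animation callback keep updating a plot it did not receive as an argument.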
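Notice that the animation object is returned from the function. This works around an issue in Matplotlib resulting from garbage collection: if no reference to the animation object is kept, it may be garbage collected and the animation will stop.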
Results
Finally, the performance of the network can be visualized in real time as the network is trained. In practice, the above code could be used to identify the point at which performance becomes satisfactory.
Figure 1: Video of Neural Network Performance over Time
Once the animation window is closed, execution resumes after the call to PlotLearn. At this point, the model could be written to a file or used to perform prediction.
Note: Spikes in the prediction graph arise because the Hamming distance between two vectors \(\hat{x}\) and \(\hat{y}\) can be small while the Euclidean distance between their integer encodings can be arbitrarily large. For example, \(0\) and \(2^{n}\) differ in only one bit position, yet their encodings are \(2^{n}\) apart.
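The effect can be checked numerically with the short sketch below. The helpers `hamming` and `to_int` are hypothetical names written for this example, and the vectors are assumed to be stored least significant bit first. Both predictions are wrong in exactly one bit, but the error in the highest-order bit moves the plotted point much further.

```python
# Hypothetical helpers; vectors are least-significant-bit first.
def hamming(u, v):
    # Number of positions at which the two vectors differ
    return sum(a != b for a, b in zip(u, v))

def to_int(bits):
    # Integer encoding of a bit vector
    return sum(int(b) << k for k, b in enumerate(bits))

n = 8
target = [0] * n
close = [1] + [0] * (n - 1)   # lowest-order bit wrong
far = [0] * (n - 1) + [1]     # highest-order bit wrong

for pred in (close, far):
    print(hamming(target, pred), abs(to_int(target) - to_int(pred)))
# 1 1
# 1 128
```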