An Example for Training MNIST with Attention LSTM in TensorFlow

We can use an LSTM or a bidirectional LSTM (BI-LSTM) to train on MNIST. The code below applies an attention mechanism to the LSTM outputs before classification (a standalone sketch of the attention step appears after the list below).

Related examples:

An Example for Training MNIST with LSTM in TensorFlow

An Example for Training MNIST with Bidirectional LSTM (BI-LSTM) in TensorFlow
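
To make the attention step concrete before the full script: for each time step t, a scalar score e_t = u^T tanh(wa h_t + ba) + bu is computed from the LSTM output h_t, the scores are turned into weights with a softmax over the time axis, and the final feature vector is the weighted sum of all outputs. Here is a minimal NumPy sketch of just this pooling step; the shapes and variable names mirror the TensorFlow script below, and the random arrays stand in for real LSTM outputs:

import numpy as np

batch, time_step, num_units = 2, 28, 256
h = np.random.randn(batch, time_step, num_units)     # stand-in for LSTM outputs
wa = np.random.randn(num_units, num_units) * 0.1
ba = np.zeros(num_units)
u = np.random.randn(num_units, 1) * 0.1
bu = np.zeros(1)

scores = np.tanh(h @ wa + ba) @ u + bu               # [batch, time_step, 1]
scores -= scores.max(axis=1, keepdims=True)          # for numerical stability
weights = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax over time
context = (weights * h).sum(axis=1)                  # [batch, num_units]

print(weights.sum(axis=1).ravel())  # each ~1.0: a distribution over the 28 steps
print(context.shape)                # (2, 256)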

import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Get MNIST data (downloaded on first run)
mnist = input_data.read_data_sets(os.getcwd() + "/MNIST-data/", one_hot=True)

# Hyperparameters
learning_rate = 1e-3
num_units = 256       # hidden units per LSTM layer
num_layer = 3         # stacked LSTM layers
input_size = 28       # features per time step (pixels per image row)
time_step = 28        # time steps per sequence (rows per image)
total_steps = 1000
category_num = 10
steps_per_validate = 5
steps_per_test = 5
batch_size = tf.placeholder(tf.int32, [])   # batch size, fed at session run time
keep_prob = tf.placeholder(tf.float32, [])  # dropout keep probability


# Build one LSTM cell with output dropout
def cell(num_units):
    lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
    # Dropout is applied only to the cell's outputs, controlled by keep_prob
    return tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)


# Initial
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
x_shape = tf.reshape(x, [-1, time_step, input_size])  # [batch_size, 28, 28]
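# Each 784-pixel image is treated as a sequence: the 28 rows are the time
# steps and the 28 pixels in each row are the per-step input features.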

# RNN Layers
cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])  # 3-layer stacked LSTM
h0 = cells.zero_state(batch_size, dtype=tf.float32)  # zero initial state for every layer
output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

# output:    [batch_size, time_step, num_units] - hidden state at every time step
# attention: [batch_size, time_step, 1]         - softmax weight for every time step
# context:   [batch_size, num_units]            - attention-weighted sum of the outputs

# Apply attention over the per-step LSTM outputs.
# wa: [num_units, num_units], ba: [num_units] - project each output h_t
# u:  [num_units, 1],         bu: [1]         - reduce the projection to a scalar score
wa = tf.Variable(tf.truncated_normal([num_units, num_units], stddev=0.1), dtype=tf.float32)
ba = tf.Variable(tf.constant(0.1, shape=[num_units]), dtype=tf.float32)
u = tf.Variable(tf.truncated_normal([num_units, 1], stddev=0.1), dtype=tf.float32)
bu = tf.Variable(tf.constant(0.1, shape=[1]), dtype=tf.float32)

# Scores: e = tanh(output . wa + ba) . u + bu, shape [batch_size, time_step, 1].
# The softmax must normalize over the time axis (axis=1): the default last axis
# has size 1, which would make every weight exactly 1.
scores = tf.tensordot(tf.tanh(tf.tensordot(output, wa, axes=1) + ba), u, axes=1) + bu
attention = tf.nn.softmax(scores, axis=1)
output = output * attention  # weight each time step's output (broadcast over units)

# Attention-weighted sum over the time steps: [batch_size, num_units].
# The weights already sum to 1 across time, so a sum (not a mean) gives the context vector.
output = tf.reduce_sum(output, 1)
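# Example: if one sample's attention weights are 0.9 on row 1, 0.1 on row 2 and
# ~0 elsewhere, its context vector is 0.9*h_1 + 0.1*h_2.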

# Output Layer
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b

# Loss (reduce to a scalar for the optimizer)
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y))
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
# Train
init = tf.global_variables_initializer()
merged = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(init)
    test_file_writer = tf.summary.FileWriter(os.getcwd() + "/MNIST-test-logs/", sess.graph)

    for step in range(total_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(100)
        sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]})
        # Train Accuracy (dropout disabled for evaluation)
        if step % steps_per_validate == 0:
            print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 1,
                                                               batch_size: batch_x.shape[0]}))
        # Test Accuracy (fetch the summaries and accuracy in a single run)
        if step % steps_per_test == 0:
            test_x, test_y = mnist.test.images, mnist.test.labels
            summary, test_acc = sess.run([merged, accuracy],
                                         feed_dict={x: test_x, y_label: test_y, keep_prob: 1,
                                                    batch_size: test_x.shape[0]})
            test_file_writer.add_summary(summary, step)
            print('Test', step, test_acc)
            
    test_file_writer.close()
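
The logged summaries can be viewed with TensorBoard (assuming it was installed alongside TensorFlow) by running tensorboard --logdir=MNIST-test-logs and opening the printed URL in a browser.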

The accuracy is:

[train/test accuracy figures not recovered]