An Example of Training MNIST with a Bidirectional LSTM (Bi-LSTM) in TensorFlow

Here is an example of training on the MNIST image dataset with a Bi-LSTM in TensorFlow. The trick is to treat each 28x28 image as a sequence of 28 rows: 28 time steps, each a 28-dimensional input vector.
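A quick standalone shape check first (a hypothetical snippet, separate from the training script below), showing how a flattened batch becomes the [batch, time_step, input_size] tensor the LSTM consumes:

import numpy as np

# A flattened MNIST batch: 100 images of 784 pixels each.
batch = np.zeros((100, 784), dtype=np.float32)

# Reinterpret each image as 28 rows of 28 pixels:
# the LSTM reads one row per time step.
sequences = batch.reshape(-1, 28, 28)
print(sequences.shape)  # (100, 28, 28) -> [batch, time_step, input_size]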

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# Load the MNIST data
mnist = input_data.read_data_sets(os.getcwd() + "/MNIST-data/", one_hot=True)

# Hyperparameters
learning_rate = 1e-3
num_units = 256
input_size = 28
time_step = 28  # 28 time steps: one image row per step
total_steps = 1000
category_num = 10
steps_per_validate = 5
steps_per_test = 5
batch_size = tf.placeholder(tf.int32, [])   # batch size, fed at run time via feed_dict
keep_prob = tf.placeholder(tf.float32, [])  # dropout keep probability

# Input placeholders
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
x_shape = tf.reshape(x, [-1, time_step, input_size])  # [batch_size, 28, 28]


# Create the forward LSTM cell
fw_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
fw_h0 = fw_lstm_cell.zero_state(batch_size, dtype=tf.float32)  # initial LSTMStateTuple (c, h), each [batch_size, num_units]
fw_drop_cell = tf.nn.rnn_cell.DropoutWrapper(fw_lstm_cell, output_keep_prob=keep_prob)

# Create the backward LSTM cell
bw_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
bw_h0 = bw_lstm_cell.zero_state(batch_size, dtype=tf.float32)  # initial LSTMStateTuple (c, h), each [batch_size, num_units]
bw_drop_cell = tf.nn.rnn_cell.DropoutWrapper(bw_lstm_cell, output_keep_prob=keep_prob)

# Build the Bi-LSTM: unstack to a list of time_step tensors, each [batch_size, input_size]
input_x = tf.unstack(x_shape, time_step, 1)
outputs, _, _ = tf.nn.static_bidirectional_rnn(
    cell_fw=fw_drop_cell, cell_bw=bw_drop_cell, inputs=input_x,
    initial_state_fw=fw_h0, initial_state_bw=bw_h0, dtype=tf.float32)

# outputs is a list of length time_step; each element is [batch_size, 2 * num_units]
# (forward and backward outputs concatenated). Average over the time axis:
output = tf.reduce_mean(outputs, 0)  # [batch_size, 2 * num_units]

# Output layer (2 * num_units features: forward and backward outputs concatenated)
w = tf.Variable(tf.truncated_normal([2 * num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b

# Loss: average softmax cross-entropy over the batch (minimize needs a scalar)
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y))
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
# Train
init = tf.global_variables_initializer()
merged = tf.summary.merge_all()
with tf.Session() as sess:
    sess.run(init)
    test_file_writer = tf.summary.FileWriter(os.getcwd() + "/MNIST-test-logs/", sess.graph)

    for step in range(total_steps + 1):
        batch_x, batch_y = mnist.train.next_batch(100)
        sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]})
        # Train accuracy (keep_prob is 1 during evaluation so dropout is disabled)
        if step % steps_per_validate == 0:
            print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 1,
                                                               batch_size: batch_x.shape[0]}))
        # Test accuracy: fetch the merged summary and the accuracy in a single run
        if step % steps_per_test == 0:
            test_x, test_y = mnist.test.images, mnist.test.labels
            summary, test_acc = sess.run([merged, accuracy],
                                         feed_dict={x: test_x, y_label: test_y, keep_prob: 1,
                                                    batch_size: test_x.shape[0]})
            test_file_writer.add_summary(summary, step)
            print('Test', step, test_acc)
            
    test_file_writer.close()
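As a side note, the same model can also be built with tf.nn.bidirectional_dynamic_rnn, which consumes the [batch_size, time_step, input_size] tensor directly and makes tf.unstack unnecessary. A minimal sketch of just the changed lines (not a drop-in addition to the graph above, since the cells would be reused):

# Hypothetical variant: dynamic instead of static unrolling.
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=fw_drop_cell, cell_bw=bw_drop_cell,
    inputs=x_shape, dtype=tf.float32)

# Each direction yields [batch_size, time_step, num_units]; concatenate them
# and average over time to mirror the static version above.
dyn_outputs = tf.concat([out_fw, out_bw], axis=2)  # [batch_size, time_step, 2 * num_units]
dyn_output = tf.reduce_mean(dyn_outputs, axis=1)   # [batch_size, 2 * num_units]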

We can compare the test accuracy with that of a plain (unidirectional) LSTM; see the earlier post:

Train MNIST with LSTM in TensorFlow
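For reference, the unidirectional version differs in only a few lines. A minimal sketch of the changed part, under the same hyperparameters as above (a single forward cell, so the output layer shrinks from 2 * num_units to num_units):

# Hypothetical unidirectional variant, replacing the Bi-LSTM block above.
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
drop_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)

input_x = tf.unstack(x_shape, time_step, 1)
outputs, _ = tf.nn.static_rnn(drop_cell, input_x, dtype=tf.float32)
output = tf.reduce_mean(outputs, 0)  # [batch_size, num_units]

# Only one direction now, so the weight matrix is [num_units, category_num].
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b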
