An Example of Training and Testing Data with TensorFlow tf.data.Dataset

Here is an example of reading training and test data with tf.data.Dataset.

#-*- coding: UTF-8 -*- 
import tensorflow as tf
import numpy as np
import os
import time

# Pipeline settings: elements per batch, shuffle buffer size, number of training steps.
batch_size, shuffle_buffer, train_steps = 3, 4, 100

# Training samples: the five consecutive pairs [1, 2], [3, 4], ..., [9, 10].
a = np.arange(1, 11).reshape(5, 2)
# Test samples.
b = np.asarray([(1, 3), (3, 3), (5, 5)])
# Wrap the training array `a` in a tf.data.Dataset (via from_tensor_slices).
train_dataset = tf.data.Dataset.from_tensor_slices(a)

# parse each file
def parse(value, delay=5):
    """Split one dataset element into a ``(data, label)`` pair.

    Args:
        value: Indexable element handed over by ``Dataset.map``.
        delay: Seconds to sleep before returning, simulating slow parsing.
            Defaults to 5, the value that used to be hard-coded.

    Returns:
        A tuple ``(data, label)`` where ``data`` is ``value`` unchanged and
        ``label`` is ``value[1]``.
    """
    # NOTE(review): in TF 1.x graph mode Dataset.map presumably traces this
    # function once, so the sleep and the print would happen while the
    # pipeline is built rather than once per element -- confirm.
    time.sleep(delay)
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print('read data')
    return value, value[1]

# Training pipeline: shuffle, parse each record, then batch.
# map() must come before batch() so that parse() receives a single record and
# `value[1]` is that record's label; mapping after batching would instead pick
# the second row of every batch as the "label".
train_dataset = train_dataset.shuffle(shuffle_buffer).map(parse).batch(batch_size)
# An initializable iterator can be rewound by re-running its initializer.
train_iterator = train_dataset.make_initializable_iterator()

train_data_batch, train_lable_batch = train_iterator.get_next()

## test
# Test pipeline: same per-record parsing, no shuffling.
test_dataset = tf.data.Dataset.from_tensor_slices(b)
test_dataset = test_dataset.map(parse).batch(batch_size)
test_iterator = test_dataset.make_initializable_iterator()

test_data_batch, test_lable_batch = test_iterator.get_next()

# Ops that initialize global and local variables; run once at session start.
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

def train(i, value, label, delay=1):
    """Stand-in for one training step: report the step and simulate work.

    Args:
        i: Index of the current training step; printed as ``train <i>``.
        value: Batch of input data (unused by this placeholder).
        label: Batch of labels (unused by this placeholder).
        delay: Seconds to sleep to simulate the cost of a training step.
            Defaults to 1, the value that used to be hard-coded.
    """
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print("train %d" % i)
    time.sleep(delay)


with tf.Session() as sess:
    sess.run([init, init_local])
    sess.run(train_iterator.initializer)
    for i in range(train_steps):
        try:
            x, y = sess.run([train_data_batch, train_lable_batch])
        except tf.errors.OutOfRangeError:
            # Training data exhausted: rewind the iterator so the next step
            # reads from the start again. No training happens in this step.
            sess.run(train_iterator.initializer)
        else:
            # Only full batches are used; an undersized batch is skipped.
            # The print must stay inside this guard so that only batches
            # that are actually trained on get printed.
            if x.shape[0] == batch_size:
                # Same output as the Python 2 statement `print x,y`, but
                # valid on both Python 2 and 3.
                print('%s %s' % (x, y))
                train(i, x, y)

        # After every training step, run through all of the test data.
        sess.run(test_iterator.initializer)
        while True:
            try:
                test_x, test_y = sess.run([test_data_batch, test_lable_batch])
            except tf.errors.OutOfRangeError:
                break
            # Accuracy would be computed here. Test elements in an undersized
            # batch are ignored:
            # num_ignored = total_test_number - m * batch_size
            if test_x.shape[0] == batch_size:
                pass

The result is:

From the result, we can see:

  1. By re-running sess.run(train_iterator.initializer) when tf.errors.OutOfRangeError occurs, we can read the dataset again from the beginning.
  2. By checking `if test_x.shape[0] == batch_size`, we can make sure that every batch of data that is used has the same size.

But we cannot do it like this:

        try:
            x,y = sess.run([train_data_batch, train_lable_batch])                
        except tf.errors.OutOfRangeError:
            sess.run(train_iterator.initializer)  
            continue
         
        # NOTE(review): the `continue` above skips the lines below for the step
        # in which OutOfRangeError is raised. There is no
        # `x.shape[0] == batch_size` guard here, so every batch that sess.run
        # does return is printed and trained on -- presumably including the
        # short final batch; confirm against the output.
        print x,y
        train(i, x, y)

The result is:

From the result, we can see:

  1. Even though tf.errors.OutOfRangeError is handled, the code below the try/except is still executed for the last, smaller batch.
  2. As a result, the batches of training data do not all have the same size.

, ,