An Example of an Attention Mechanism in NLP

# TensorFlow 1.x graph-mode example.
import tensorflow as tf
import numpy as np

x = np.array([[[0, 1, 2], [1, 1, 0], [1, 2, 0]],
              [[0, 1, -1], [1, 1, 2], [1, 2, 0]]])
# shape = (2, 3, 3): batch size B=2, timesteps T=3, input dimension D=3
print(x)

inputs = tf.convert_to_tensor(x, dtype=tf.float32)

# Attention parameters: project each D=3 input to an A=5 hidden representation,
# then score it against a learned context vector.
w_omega = tf.Variable(tf.random_normal([3, 5], stddev=0.1))
b_omega = tf.Variable(tf.random_normal([5], stddev=0.1))
u_omega = tf.Variable(tf.random_normal([5], stddev=0.1))  # context vector, shape (A,)
u = tf.Variable(tf.random_normal([5, 1], stddev=0.1))     # second context vector, kept as an (A, 1) matrix


with tf.name_scope('v'):
    # Fully connected layer with tanh applied to each of the B*T timesteps;
    # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A = attention_size.
    v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)

    # For each timestep, its vector of size A in `v` is reduced with the `u_omega` vector.
    vu = tf.tensordot(v, u_omega, axes=1, name='vu')    # (B,T) shape

    # The same scoring done with the (A,1) matrix `u` keeps an explicit trailing dimension.
    vuu = tf.tensordot(v, u, axes=1, name='vuu')        # (B,T,1) shape

    alphas = tf.nn.softmax(vu, name='alphas')                  # (B,T) shape
    alphas_2 = tf.nn.softmax(vuu, name='alphas_2', axis=1)     # (B,T,1) shape; softmax over the time axis

    # Each timestep of the input is scaled by its attention weight. Note the
    # result keeps the (B,T,D) shape; summing over T would give the usual
    # (B,D) attention output.
    t = tf.expand_dims(alphas, -1)                      # (B,T,1)
    output = inputs * t
    output_2 = inputs * alphas_2                        # broadcasts over D


init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(alphas))     # (2, 3) attention weights
    print(sess.run(t))          # (2, 3, 1)
    print(sess.run(alphas_2))   # (2, 3, 1)
    print(sess.run(output))     # (2, 3, 3) weighted inputs
    print(sess.run(output_2))   # (2, 3, 3)
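
The script stops at the element-wise weighting; the usual attention output additionally sums over the time axis to collapse each example into a single (B,D) summary, as the comment above alludes to. A minimal sketch of that pooling step (an addition, not part of the original script, assuming the same graph and initializer):

# Hypothetical pooling step: sum the weighted timesteps so each
# example is reduced to one (D,)-vector, giving shape (B, D).
pooled = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), axis=1)

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(pooled))   # shape (2, 3)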

The result is:

[[0.3313902  0.34481937 0.3237904 ]
 [0.34879792 0.33012047 0.3210816 ]]
[[[0.3313902 ]
  [0.34481937]
  [0.3237904 ]]

 [[0.34879792]
  [0.33012047]
  [0.3210816 ]]]
[[[0.33401328]
  [0.32772106]
  [0.33826563]]

 [[0.31965455]
  [0.33990833]
  [0.3404371 ]]]
[[[ 0.          0.3313902   0.6627804 ]
  [ 0.34481937  0.34481937  0.        ]
  [ 0.3237904   0.6475808   0.        ]]

 [[ 0.          0.34879792 -0.34879792]
  [ 0.33012047  0.33012047  0.66024095]
  [ 0.3210816   0.6421632   0.        ]]]
[[[ 0.          0.33401328  0.66802657]
  [ 0.32772106  0.32772106  0.        ]
  [ 0.33826563  0.67653126  0.        ]]

 [[ 0.          0.31965455 -0.31965455]
  [ 0.33990833  0.33990833  0.67981666]
  [ 0.3404371   0.6808742   0.        ]]]

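Both softmax variants produce a proper probability distribution over the T=3 timesteps. A quick NumPy check against the printed weights (a minimal sketch; the values are copied from the first output block above):

import numpy as np

alphas = np.array([[0.3313902, 0.34481937, 0.3237904],
                   [0.34879792, 0.33012047, 0.3210816]])
# Each row sums to 1, so the weights form a distribution over the timesteps.
print(alphas.sum(axis=1))   # ~[1. 1.], up to rounding of the printed weights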