0%

TensorFlow_dynamic_rnn_AND_bidirectional_dynamic_rnn使用

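dynamic_rnn 源码 (source): https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/ops/rnn.py — both `tf.nn.dynamic_rnn` and `tf.nn.bidirectional_dynamic_rnn` are defined there. (The contrib/recurrent/python/ops/functional_rnn.py file implements the separate `functional_rnn` API, not `dynamic_rnn`.)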

lstm_dynamic_rnn

1
2
3
4
5
6
7
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Begin this example with a fresh default graph; tf.reset_default_graph is the
# public alias of ops.reset_default_graph in TF 1.x.
tf.reset_default_graph()

tf.__version__
'1.12.0'
1
# Random input batch laid out as (batch=3, time=4, features=5); dynamic_rnn
# defaults to batch-major input (time_major=False).
inputs_tensor = tf.constant(np.random.random((3, 4, 5)), dtype=tf.float32)

# One LSTM layer whose hidden/cell state has 6 units.
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=6)

# Unroll the cell over the time axis. `output` collects h for every step,
# `state` is the final LSTMStateTuple(c, h).
output, state = tf.nn.dynamic_rnn(
    cell=lstm_cell,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
output
<tf.Tensor 'rnn/transpose_1:0' shape=(3, 4, 6) dtype=float32>
1
# Final LSTM state: an LSTMStateTuple with cell state `c` and hidden state `h`,
# each shaped (batch, num_units) = (3, 6).
state
LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(3, 6) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(3, 6) dtype=float32>)
1
2
# Build the initializer op for every global variable (the LSTM weights created
# by dynamic_rnn), open a session on the default graph, and run it.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(output)
array([[[ 0.08616676,  0.00701726,  0.02770402,  0.06492911,
          0.0510945 ,  0.00969434],
        [ 0.1320394 ,  0.00555411,  0.00191859,  0.08497527,
          0.07742812,  0.05351048],
        [ 0.22318284,  0.00047223,  0.00479356, -0.07421347,
          0.10626906,  0.04756515],
        [ 0.2620672 ,  0.00858223, -0.00066401, -0.05911936,
          0.12364249,  0.03462988]],

       [[ 0.06325776, -0.01175407,  0.04182662,  0.02187428,
          0.07014529, -0.01354775],
        [ 0.1437409 ,  0.01115061,  0.03309705,  0.00872881,
          0.1256691 , -0.01869454],
        [ 0.1477679 ,  0.0219029 , -0.00088758, -0.04972449,
          0.10450882, -0.00912069],
        [ 0.24747844, -0.02095445,  0.08351423, -0.00707687,
          0.0659603 , -0.01971625]],

       [[ 0.08307187,  0.00345412,  0.02962245,  0.06890179,
          0.03142677,  0.01744366],
        [ 0.1346873 , -0.05403996,  0.08703925, -0.00265777,
          0.04009543, -0.0071087 ],
        [ 0.14429979, -0.04602255,  0.04806704, -0.01820353,
          0.09431574,  0.00280121],
        [ 0.22852188, -0.05865859,  0.04669214, -0.07948805,
          0.06424228,  0.03090438]]], dtype=float32)
1
# Final hidden state h: identical to the last time step of `output` above.
sess.run(state.h)
array([[ 0.2620672 ,  0.00858223, -0.00066401, -0.05911936,  0.12364249,
         0.03462988],
       [ 0.24747844, -0.02095445,  0.08351423, -0.00707687,  0.0659603 ,
        -0.01971625],
       [ 0.22852188, -0.05865859,  0.04669214, -0.07948805,  0.06424228,
         0.03090438]], dtype=float32)
1
# Final cell state c: not part of `output`, which only exposes h per step.
sess.run(state.c)
array([[ 0.5363698 ,  0.02536838, -0.00174306, -0.0865218 ,  0.2190451 ,
         0.06242979],
       [ 0.45945585, -0.04819317,  0.1650849 , -0.01152224,  0.12156412,
        -0.03720763],
       [ 0.40406716, -0.15820494,  0.11020815, -0.12061047,  0.11127482,
         0.0575163 ]], dtype=float32)

gru_dynamic_rnn

1
2
3
4
5
6
7
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Discard the previous example's graph before building the GRU version.
tf.reset_default_graph()

tf.__version__
'1.12.0'
1
# Same (batch=3, time=4, features=5) random input layout as the LSTM example.
inputs_tensor = tf.constant(np.random.random((3, 4, 5)), dtype=tf.float32)

# One GRU layer with 6 units; its state is a single tensor rather than a (c, h) pair.
gru_cell = tf.nn.rnn_cell.GRUCell(num_units=6)

output, state = tf.nn.dynamic_rnn(
    cell=gru_cell,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
output
<tf.Tensor 'rnn/transpose_1:0' shape=(3, 4, 6) dtype=float32>
1
# GRU final state: one (batch, num_units) = (3, 6) tensor.
state
<tf.Tensor 'rnn/while/Exit_3:0' shape=(3, 6) dtype=float32>
1
2
# Initialize the GRU weights in a new session, then evaluate the per-step outputs.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(output)
array([[[ 0.10332259, -0.11011579,  0.02266729,  0.0007517 ,
          0.03512604,  0.04907914],
        [ 0.20385078, -0.23856235,  0.1248817 , -0.08083571,
          0.11375216,  0.1124958 ],
        [ 0.3623113 , -0.32147223,  0.0005917 ,  0.05736844,
          0.19632787,  0.17149314],
        [ 0.3957892 , -0.25175768,  0.13963164, -0.01752499,
          0.19986765,  0.20010342]],

       [[ 0.09015381, -0.01318468,  0.04039504,  0.00195809,
         -0.03166562, -0.00577726],
        [ 0.1832312 , -0.00861337,  0.07086013, -0.05410649,
         -0.11791474, -0.12356216],
        [ 0.262054  , -0.21013457,  0.05953048,  0.00513211,
          0.00651281, -0.02495475],
        [ 0.37156823, -0.19515839,  0.18966836, -0.04400685,
          0.0111442 , -0.03699975]],

       [[ 0.11565635, -0.11679434, -0.06442562,  0.05251991,
         -0.01245365, -0.00727599],
        [ 0.17201838, -0.24277222, -0.07361332, -0.03645991,
          0.10429659, -0.00199607],
        [ 0.26178688, -0.21496084, -0.05825588, -0.05373647,
          0.09104574, -0.04549351],
        [ 0.3539365 , -0.20948091, -0.02887814, -0.00197285,
          0.06114171, -0.05429474]]], dtype=float32)
1
# The final GRU state equals the last time step of `output` for each sequence.
sess.run(state)
array([[ 0.3957892 , -0.25175768,  0.13963164, -0.01752499,  0.19986765,
         0.20010342],
       [ 0.37156823, -0.19515839,  0.18966836, -0.04400685,  0.0111442 ,
        -0.03699975],
       [ 0.3539365 , -0.20948091, -0.02887814, -0.00197285,  0.06114171,
        -0.05429474]], dtype=float32)

Multi_lstm_dynamic_rnn

1
2
3
4
5
6
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Fresh default graph for the stacked-LSTM example.
tf.reset_default_graph()
tf.__version__
 '1.12.0'
1
inputs_tensor = tf.constant(np.random.random((3, 4, 5)), dtype=tf.float32)

# Three stacked LSTM layers of widths 32 -> 16 -> 8; each layer's output is
# the next layer's input, so the stack's output width is the last layer's (8).
lstm_cell_units_list = [32, 16, 8]
lstm_cells = [tf.nn.rnn_cell.LSTMCell(num_units=width) for width in lstm_cell_units_list]
multi_lstm_cells = tf.nn.rnn_cell.MultiRNNCell(cells=lstm_cells)

outputs, states = tf.nn.dynamic_rnn(
    cell=multi_lstm_cells,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
outputs
 <tf.Tensor 'rnn/transpose_1:0' shape=(3, 4, 8) dtype=float32>
1
# One LSTMStateTuple(c, h) per stacked layer, in layer order (32, 16, 8 units).
states
 (LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(3, 32) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(3, 32) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(3, 16) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(3, 16) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(3, 8) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(3, 8) dtype=float32>))
1
2
# Initialize all three layers' weights, then evaluate the top layer's per-step outputs.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(outputs)
 array([[[-1.61404535e-03, -5.39962610e-04,  5.39611437e-06,
           2.28725166e-05,  9.44484840e-04, -4.58947802e-03,
           1.46846834e-03, -7.76402245e-04],
         [-3.09936469e-03, -4.27109393e-04,  1.62538933e-03,
          -3.19648359e-04,  1.73521251e-03, -1.22727510e-02,
           4.70003067e-03, -2.35606125e-03],
         [-2.51045800e-03, -7.78585789e-04,  2.94276653e-03,
           7.99274276e-05,  2.68396758e-03, -2.12856531e-02,
           9.49617289e-03, -2.14548036e-03],
         [ 1.16307237e-04, -4.06709267e-04,  4.33099223e-03,
           1.24737364e-03,  3.62196262e-03, -3.07264701e-02,
           1.57848410e-02, -7.23684439e-04]],

        [[-3.23332497e-05,  2.67408788e-04,  7.24235782e-04,
          -2.50815327e-04,  2.83148460e-04, -2.27145315e-03,
           1.13080570e-03, -1.57606765e-03],
         [ 2.05371980e-04, -1.17544514e-04,  7.81793962e-04,
           1.08315377e-04,  1.33984850e-03, -7.61773949e-03,
           3.80030414e-03, -2.82652606e-03],
         [ 8.20874877e-04, -1.12650392e-03, -8.80288571e-05,
           1.86283619e-03,  2.56655412e-03, -1.47406263e-02,
           7.87232257e-03, -2.76046596e-03],
         [ 2.21824460e-03, -2.74996134e-03, -1.48778886e-03,
           5.04738186e-03,  3.63758323e-03, -2.35401765e-02,
           1.35801937e-02, -2.54129106e-03]],

        [[-8.64092377e-04, -3.04833520e-04, -1.56841183e-04,
          -1.28708722e-04,  9.67342989e-04, -2.75324681e-03,
           8.15562380e-04, -6.81574922e-04],
         [-2.73647415e-03, -1.36391085e-03, -5.19481488e-04,
          -4.36288465e-05,  2.77121575e-03, -9.15816333e-03,
           2.90221907e-03, -5.05329692e-04],
         [-3.75988754e-03, -3.24360654e-03, -1.28453283e-03,
           4.84427328e-05,  5.25222160e-03, -1.69832408e-02,
           5.55129535e-03, -2.57142878e-04],
         [-4.63895453e-03, -5.13906591e-03, -1.28147972e-03,
           4.88204561e-04,  7.37427827e-03, -2.69612260e-02,
           9.37340409e-03, -1.25891145e-03]]], dtype=float32)
1
# Final h of the top (third, 8-unit) layer: matches the last time step of `outputs`.
sess.run(states[2].h)
 array([[ 0.00011631, -0.00040671,  0.00433099,  0.00124737,  0.00362196,
         -0.03072647,  0.01578484, -0.00072368],
        [ 0.00221824, -0.00274996, -0.00148779,  0.00504738,  0.00363758,
         -0.02354018,  0.01358019, -0.00254129],
        [-0.00463895, -0.00513907, -0.00128148,  0.0004882 ,  0.00737428,
         -0.02696123,  0.0093734 , -0.00125891]], dtype=float32)
1
# Final cell state c of the top layer (not exposed through `outputs`).
sess.run(states[2].c)
 array([[ 0.00022726, -0.00082744,  0.00844097,  0.00248186,  0.00721051,
         -0.06110654,  0.03194349, -0.00143018],
        [ 0.00437538, -0.00561603, -0.00291193,  0.01000925,  0.00718385,
         -0.04696855,  0.02727731, -0.00508054],
        [-0.00914405, -0.01050362, -0.00249997,  0.00097318,  0.01467903,
         -0.05430028,  0.01892203, -0.00251437]], dtype=float32)

Multi_gru_dynamic_rnn

1
2
3
4
5
6
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Fresh default graph for the stacked-GRU example.
tf.reset_default_graph()
tf.__version__
'1.12.0'
1
inputs_tensor = tf.constant(np.random.random((3, 4, 5)), dtype=tf.float32)

# Three stacked GRU layers of widths 32 -> 16 -> 8.
gru_cell_units_list = [32, 16, 8]
gru_cells = [tf.nn.rnn_cell.GRUCell(num_units=width) for width in gru_cell_units_list]
multi_gru_cells = tf.nn.rnn_cell.MultiRNNCell(cells=gru_cells)

outputs, states = tf.nn.dynamic_rnn(
    cell=multi_gru_cells,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
outputs
<tf.Tensor 'rnn/transpose_1:0' shape=(3, 4, 8) dtype=float32>
1
# One final-state tensor per stacked GRU layer: shapes (3, 32), (3, 16), (3, 8).
states
(<tf.Tensor 'rnn/while/Exit_3:0' shape=(3, 32) dtype=float32>,
 <tf.Tensor 'rnn/while/Exit_4:0' shape=(3, 16) dtype=float32>,
 <tf.Tensor 'rnn/while/Exit_5:0' shape=(3, 8) dtype=float32>)
1
2
# Initialize the stacked-GRU weights and evaluate the top layer's per-step outputs.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(outputs)
array([[[-3.0689167e-03,  4.7762650e-03, -5.6350869e-03,  7.3930359e-04,
          2.7776770e-03,  1.6547712e-03,  4.0468639e-03, -1.6071725e-03],
        [-9.6203415e-03,  1.0396119e-02, -1.2028635e-02, -7.5696781e-04,
          9.5890947e-03,  3.0168635e-03,  9.1733877e-03, -6.7313079e-04],
        [-1.6855957e-02,  2.2866743e-02, -2.2569295e-02, -3.3538719e-03,
          2.0547308e-02, -3.6199810e-05,  1.4878229e-02,  8.4475791e-03],
        [-2.5817569e-02,  4.0700834e-02, -3.5192616e-02, -8.7473392e-03,
          3.5584368e-02, -9.0519777e-03,  2.4247481e-02,  2.3476347e-02]],

       [[-3.0218370e-03,  6.4162901e-03, -6.9783311e-03,  3.8428712e-04,
          5.0610453e-03, -6.1881670e-04,  3.7080410e-04,  5.2686040e-03],
        [-1.0377670e-02,  1.5972832e-02, -1.6679505e-02, -9.3015225e-04,
          1.7201036e-02, -1.5078825e-03,  1.3457731e-03,  1.2340306e-02],
        [-2.2288060e-02,  2.7055936e-02, -2.7431497e-02, -4.4479482e-03,
          3.2678001e-02, -2.9535883e-03,  7.6549058e-03,  1.8373087e-02],
        [-3.6172852e-02,  4.3594681e-02, -4.2179834e-02, -3.7945234e-03,
          4.7429617e-02, -2.6138786e-03,  1.7336009e-02,  2.5696296e-02]],

       [[-7.6424627e-04,  7.8961477e-03, -7.1330541e-03,  1.9744530e-03,
          4.7649448e-03, -3.9919489e-04, -1.7578194e-04,  3.7013867e-03],
        [-1.6916599e-03,  2.1197245e-02, -1.7833594e-02,  1.4740386e-03,
          1.3453390e-02, -5.6854654e-03,  1.9273567e-03,  1.2227280e-02],
        [-5.0865025e-03,  3.9313547e-02, -3.0530766e-02, -3.3563119e-05,
          2.9243957e-02, -1.4064199e-02,  4.2982912e-03,  2.5105122e-02],
        [-1.1633213e-02,  6.1902620e-02, -4.5463178e-02, -2.4870001e-03,
          5.1334824e-02, -2.5908694e-02,  7.0450706e-03,  4.3221094e-02]]],
      dtype=float32)
1
# Final state of the top (8-unit) GRU layer: equals the last time step of `outputs`.
sess.run(states[2])
array([[-0.02581757,  0.04070083, -0.03519262, -0.00874734,  0.03558437,
        -0.00905198,  0.02424748,  0.02347635],
       [-0.03617285,  0.04359468, -0.04217983, -0.00379452,  0.04742962,
        -0.00261388,  0.01733601,  0.0256963 ],
       [-0.01163321,  0.06190262, -0.04546318, -0.002487  ,  0.05133482,
        -0.02590869,  0.00704507,  0.04322109]], dtype=float32)

Multi_lstm_bidirectional_dynamic_rnn

1
2
3
4
5
6
7
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Fresh default graph for the bidirectional stacked-LSTM example.
tf.reset_default_graph()

tf.__version__
'1.12.0'
1
inputs_tensor = tf.constant(np.random.random((3, 4, 5)), dtype=tf.float32)

# Forward stack 32 -> 16 -> 8 units, backward stack 33 -> 17 -> 9 units.
# Different widths make the two directions' results distinguishable by shape.
cell_fw_units_list = [32, 16, 8]
cell_bw_units_list = [33, 17, 9]
cell_fw = [tf.nn.rnn_cell.LSTMCell(num_units=width) for width in cell_fw_units_list]
cell_bw = [tf.nn.rnn_cell.LSTMCell(num_units=width) for width in cell_bw_units_list]
lstm_forward = tf.nn.rnn_cell.MultiRNNCell(cells=cell_fw)
lstm_backword = tf.nn.rnn_cell.MultiRNNCell(cells=cell_bw)

# Returns (outputs_fw, outputs_bw) and (states_fw, states_bw).
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=lstm_forward,
    cell_bw=lstm_backword,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
outputs
(<tf.Tensor 'bidirectional_rnn/fw/fw/transpose_1:0' shape=(3, 4, 8) dtype=float32>,
 <tf.Tensor 'ReverseV2:0' shape=(3, 4, 9) dtype=float32>)
1
# (forward_states, backward_states); each holds one LSTMStateTuple(c, h) per layer.
states
((LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(3, 32) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(3, 32) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(3, 16) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_6:0' shape=(3, 16) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_7:0' shape=(3, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_8:0' shape=(3, 8) dtype=float32>)),
 (LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(3, 33) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(3, 33) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(3, 17) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_6:0' shape=(3, 17) dtype=float32>),
  LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_7:0' shape=(3, 9) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_8:0' shape=(3, 9) dtype=float32>)))
1
# Forward-direction outputs of the top forward layer: (batch, time, 8).
outputs[0]
<tf.Tensor 'bidirectional_rnn/fw/fw/transpose_1:0' shape=(3, 4, 8) dtype=float32>
1
# Backward-direction outputs of the top backward layer: (batch, time, 9).
# The ReverseV2 op name reflects that they are flipped back into input time order.
outputs[1]
<tf.Tensor 'ReverseV2:0' shape=(3, 4, 9) dtype=float32>
1
# Forward stack's final states, one LSTMStateTuple per layer (32, 16, 8 units).
states[0]
(LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(3, 32) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(3, 32) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(3, 16) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_6:0' shape=(3, 16) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_7:0' shape=(3, 8) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_8:0' shape=(3, 8) dtype=float32>))
1
# Backward stack's final states, one LSTMStateTuple per layer (33, 17, 9 units).
states[1]
(LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(3, 33) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(3, 33) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(3, 17) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_6:0' shape=(3, 17) dtype=float32>),
 LSTMStateTuple(c=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_7:0' shape=(3, 9) dtype=float32>, h=<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_8:0' shape=(3, 9) dtype=float32>))
1
2
# Initialize both directions' weights, then evaluate the forward outputs.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(outputs[0])
array([[[-0.00076688, -0.00200841, -0.00236855,  0.00039911,
         -0.00145523, -0.00153653, -0.00134254,  0.00106617],
        [-0.00454114, -0.00330769, -0.00736781,  0.00173002,
         -0.0032972 , -0.00345388, -0.00293425,  0.00292132],
        [-0.01027342, -0.00316362, -0.01381279,  0.00462259,
         -0.00577486, -0.00438865, -0.00475066,  0.00504041],
        [-0.01815201, -0.00150841, -0.0213333 ,  0.00920272,
         -0.00753702, -0.0049286 , -0.00753725,  0.00764989]],

       [[-0.00051021, -0.00109504, -0.00235562,  0.00033509,
         -0.00184875, -0.00063042, -0.00011131,  0.00110563],
        [-0.00254402, -0.00287243, -0.00706615,  0.00204808,
         -0.00484681, -0.00206984, -0.0014233 ,  0.00308458],
        [-0.00708036, -0.00366582, -0.01411008,  0.00544437,
         -0.0080783 , -0.0034703 , -0.00328208,  0.00583236],
        [-0.01305352, -0.00295205, -0.02222473,  0.01060574,
         -0.01145888, -0.00383543, -0.0055557 ,  0.0088768 ]],

       [[-0.00123942,  0.00074371, -0.00094162,  0.00037477,
         -0.00032651,  0.0001554 ,  0.00046975,  0.00046324],
        [-0.00429494,  0.00162228, -0.00351843,  0.00242352,
         -0.00078484, -0.00083089, -0.00064292,  0.00222571],
        [-0.00847503,  0.00332979, -0.00786664,  0.00663501,
         -0.00156324, -0.00165966, -0.00228085,  0.00515618],
        [-0.01371686,  0.00522377, -0.01415662,  0.01267656,
         -0.00229844, -0.00248059, -0.00501328,  0.00901466]]],
      dtype=float32)
1
# Top forward layer's final h: equals the last time step of outputs[0].
sess.run(states[0][2].h)
array([[-0.01815201, -0.00150841, -0.0213333 ,  0.00920272, -0.00753702,
        -0.0049286 , -0.00753725,  0.00764989],
       [-0.01305352, -0.00295205, -0.02222473,  0.01060574, -0.01145888,
        -0.00383543, -0.0055557 ,  0.0088768 ],
       [-0.01371686,  0.00522377, -0.01415662,  0.01267656, -0.00229844,
        -0.00248059, -0.00501328,  0.00901466]], dtype=float32)
1
# Top forward layer's final cell state c.
sess.run(states[0][2].c)
array([[-0.03619757, -0.00299038, -0.04237926,  0.01859355, -0.01471011,
        -0.00991257, -0.01535283,  0.01467649],
       [-0.02600891, -0.00582871, -0.04436516,  0.02145906, -0.0222083 ,
        -0.00767388, -0.01128497,  0.01682222],
       [-0.02723665,  0.0103671 , -0.0282324 ,  0.02551894, -0.00446999,
        -0.00497827, -0.01015829,  0.01708832]], dtype=float32)

Multi_gru_bidirectional_dynamic_rnn

1
2
3
4
5
6
7
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops

# Fresh default graph for the bidirectional stacked-GRU example.
tf.reset_default_graph()

tf.__version__
'1.12.0'
1
# Random input batch laid out as (batch=3, time=4, features=5).
inputs_tensor = tf.constant(np.random.random(size=(3, 4, 5)), dtype=tf.float32)

# Three stacked GRU layers per direction: forward 32 -> 16 -> 8 units,
# backward 33 -> 17 -> 9 units (distinct widths make the directions
# distinguishable by shape).
cell_fw_units_list = [32, 16, 8]
cell_bw_units_list = [33, 17, 9]
cell_fw = [tf.nn.rnn_cell.GRUCell(num_units=unit) for unit in cell_fw_units_list]
cell_bw = [tf.nn.rnn_cell.GRUCell(num_units=unit) for unit in cell_bw_units_list]
gru_forward = tf.nn.rnn_cell.MultiRNNCell(cells=cell_fw)
gru_backward = tf.nn.rnn_cell.MultiRNNCell(cells=cell_bw)

# Fix: this call previously received `lstm_forward`/`lstm_backword`.
# Those are the LSTM stacks from the previous example, and they are undefined
# if this example runs on its own. As a result, the GRU stacks built above
# were never used. Pass the GRU stacks instead.
outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=gru_forward,
    cell_bw=gru_backward,
    inputs=inputs_tensor,
    dtype=tf.float32,
)
outputs
(<tf.Tensor 'bidirectional_rnn/fw/fw/transpose_1:0' shape=(3, 4, 8) dtype=float32>,
 <tf.Tensor 'ReverseV2:0' shape=(3, 4, 9) dtype=float32>)
1
# (forward_states, backward_states); each is a tuple with one final-state
# tensor per stacked GRU layer.
states
((<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(3, 32) dtype=float32>,
  <tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(3, 16) dtype=float32>,
  <tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(3, 8) dtype=float32>),
 (<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(3, 33) dtype=float32>,
  <tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(3, 17) dtype=float32>,
  <tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(3, 9) dtype=float32>))
1
# Forward-direction outputs of the top forward layer: (batch, time, 8).
outputs[0]
<tf.Tensor 'bidirectional_rnn/fw/fw/transpose_1:0' shape=(3, 4, 8) dtype=float32>
1
# Backward-direction outputs of the top backward layer: (batch, time, 9).
# The ReverseV2 op name reflects that they are flipped back into input time order.
outputs[1]
<tf.Tensor 'ReverseV2:0' shape=(3, 4, 9) dtype=float32>
1
# Forward stack's final states: shapes (3, 32), (3, 16), (3, 8).
states[0]
(<tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(3, 32) dtype=float32>,
 <tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(3, 16) dtype=float32>,
 <tf.Tensor 'bidirectional_rnn/fw/fw/while/Exit_5:0' shape=(3, 8) dtype=float32>)
1
# Backward stack's final states: shapes (3, 33), (3, 17), (3, 9).
states[1]
(<tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_3:0' shape=(3, 33) dtype=float32>,
 <tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_4:0' shape=(3, 17) dtype=float32>,
 <tf.Tensor 'bidirectional_rnn/bw/bw/while/Exit_5:0' shape=(3, 9) dtype=float32>)
1
2
# Initialize both directions' weights, then evaluate the forward outputs.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

sess.run(outputs[0])
array([[[-6.90058898e-03, -2.82262132e-04,  2.41189310e-03,
         -1.32219959e-03, -1.17578653e-04, -1.02065690e-03,
         -3.85798234e-03, -6.75282069e-03],
        [-1.89461559e-02, -1.18712685e-03,  8.09784234e-03,
         -5.38685592e-04, -2.89135100e-03, -5.51631721e-03,
         -1.24855377e-02, -1.74822416e-02],
        [-3.18835154e-02, -3.26077780e-03,  1.75229590e-02,
          2.73799687e-03, -9.39253345e-03, -1.30872596e-02,
         -2.38550343e-02, -3.05617414e-02],
        [-4.67147976e-02, -6.17155526e-03,  2.97614653e-02,
          6.51888642e-03, -1.87356286e-02, -2.50062142e-02,
         -3.67445275e-02, -4.68713641e-02]],

       [[-8.41161050e-03, -2.88979820e-04,  3.92245734e-03,
          2.17063329e-03, -5.09262085e-04, -2.77552451e-03,
         -5.37689496e-03, -3.95025592e-03],
        [-2.47462429e-02, -1.29592582e-03,  1.10453246e-02,
          8.48548859e-03, -2.65943818e-03, -1.01975258e-02,
         -1.82780307e-02, -1.22013604e-02],
        [-3.99502814e-02, -3.65575845e-03,  1.86709110e-02,
          1.63780842e-02, -8.04547779e-03, -2.07924694e-02,
         -3.48701701e-02, -2.20657662e-02],
        [-5.38205951e-02, -7.08078314e-03,  2.95224246e-02,
          2.77038664e-02, -1.74766369e-02, -3.49066183e-02,
         -5.22553027e-02, -3.21215130e-02]],

       [[-4.29279124e-03, -1.64932735e-05,  5.42476214e-03,
          4.06091474e-03, -3.88002302e-03, -1.09710405e-03,
         -2.73891282e-03, -1.10962777e-03],
        [-1.24491388e-02, -5.43807575e-04,  1.43955573e-02,
          8.08995403e-03, -1.06180515e-02, -3.86995194e-03,
         -8.71607196e-03, -7.86763430e-03],
        [-2.33881716e-02, -1.67222926e-03,  2.73847207e-02,
          1.45246573e-02, -2.10378524e-02, -1.14183174e-02,
         -1.76941082e-02, -1.70471650e-02],
        [-3.55363414e-02, -3.23894341e-03,  4.07814831e-02,
          2.15009488e-02, -3.38361487e-02, -2.32368447e-02,
         -3.00667081e-02, -2.97910050e-02]]], dtype=float32)
1
# Top forward GRU layer's final state: equals the last time step of outputs[0].
sess.run(states[0][2])
array([[-0.0467148 , -0.00617156,  0.02976147,  0.00651889, -0.01873563,
        -0.02500621, -0.03674453, -0.04687136],
       [-0.0538206 , -0.00708078,  0.02952242,  0.02770387, -0.01747664,
        -0.03490662, -0.0522553 , -0.03212151],
       [-0.03553634, -0.00323894,  0.04078148,  0.02150095, -0.03383615,
        -0.02323684, -0.03006671, -0.029791  ]], dtype=float32)
本站所有文章和源码均免费开放,如您喜欢,可以请我喝杯咖啡