embedding_lookup( )的用法
關(guān)于tensorflow中embedding_lookup( )的用法,在Udacity的word2vec會(huì)涉及到,本文將通俗的進(jìn)行解釋
#!/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np
input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))
代碼中先使用palceholder定義了一個(gè)未知變量input_ids用于存儲(chǔ)索引,和一個(gè)已知變量embedding,是一個(gè)5*5的對(duì)角矩陣。
運(yùn)行結(jié)果為:
embedding = [[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]]
input_embedding = [[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 0 1 0]
[0 0 1 0 0]
[0 1 0 0 0]]
簡(jiǎn)單的講就是根據(jù)input_ids中的id,尋找embedding中的對(duì)應(yīng)元素。比如,input_ids=[1,3,5],則找出embedding中下標(biāo)為1,3,5的向量組成一個(gè)矩陣返回。
如果將input_ids改寫成下面的格式:
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
print(sess.run(input_embedding, feed_dict={input_ids:[[1, 2], [2, 1], [3, 3]]}))
輸出結(jié)果就會(huì)變成如下的格式:
[[[0 1 0 0 0]
[0 0 1 0 0]]
[[0 0 1 0 0]
[0 1 0 0 0]]
[[0 0 0 1 0]
[0 0 0 1 0]]]
對(duì)比上下兩個(gè)結(jié)果不難發(fā)現(xiàn),相當(dāng)于在np.array中直接采用下標(biāo)數(shù)組獲取數(shù)據(jù)。需要注意的細(xì)節(jié)是返回的tensor的dtype和傳入的被查詢的tensor的dtype保持一致;和ids的dtype無關(guān)。