Tensorflow 的NCE-Loss的實現(xiàn)和word2vec

這兩天因為實現(xiàn)mxnetnce-loss,因此研究了一下tensorflow的nce-loss的實現(xiàn)。所以總結(jié)一下。

先看看tensorflow的nce-loss的API:

def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
             num_true=1,
             sampled_values=None,
             remove_accidental_hits=False,
             partition_strategy="mod",
             name="nce_loss")

假設(shè)nce_loss之前的輸入數(shù)據(jù)是K維的,一共有N個類,那么

  • weight.shape = (N, K)
  • bias.shape = (N)
  • inputs.shape = (batch_size, K)
  • labels.shape = (batch_size, num_true)
  • num_true : 實際的正樣本個數(shù)
  • num_sampled: 采樣出多少個負樣本
  • num_classes = N
  • sampled_values: 采樣出的負樣本,如果是None,就會用不同的sampler去采樣。待會兒說sampler是什么。
  • remove_accidental_hits: 如果采樣時不小心采樣到的負樣本剛好是正樣本,要不要干掉
  • partition_strategy:對weights進行embedding_lookup時并行查表時的策略。TF的embeding_lookup是在CPU里實現(xiàn)的,這里需要考慮多線程查表時的鎖的問題。

nce_loss的實現(xiàn)邏輯如下:

  • _compute_sampled_logits: 通過這個函數(shù)計算出正樣本和采樣出的負樣本對應(yīng)的output和label
  • sigmoid_cross_entropy_with_logits: 通過 sigmoid cross entropy來計算output和label的loss,從而進行反向傳播。這個函數(shù)把最后的問題轉(zhuǎn)化為了num_sampled+num_real個兩類分類問題,然后每個分類問題用了交叉熵的損傷函數(shù),也就是logistic regression常用的損失函數(shù)。TF里還提供了一個softmax_cross_entropy_with_logits的函數(shù),和這個有所區(qū)別。

再來看看TF里word2vec的實現(xiàn),他用到nce_loss的代碼如下:

  loss = tf.reduce_mean(
      tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,
                     num_sampled, vocabulary_size))

可以看到,它這里并沒有傳sampled_values,那么它的負樣本是怎么得到的呢?繼續(xù)看nce_loss的實現(xiàn),可以看到里面處理sampled_values=None的代碼如下:

    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)

所以,默認情況下,他會用log_uniform_candidate_sampler去采樣。那么log_uniform_candidate_sampler是怎么采樣的呢?他的實現(xiàn)在這里

  • 他會在[0, range_max)中采樣出一個整數(shù)k
  • P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1)

可以看到,k越大,被采樣到的概率越小。那么在TF的word2vec里,類別的編號有什么含義嗎?看下面的代碼:

def build_dataset(words):
  count = [['UNK', -1]]
  count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
  dictionary = dict()
  for word, _ in count:
    dictionary[word] = len(dictionary)
  data = list()
  unk_count = 0
  for word in words:
    if word in dictionary:
      index = dictionary[word]
    else:
      index = 0  # dictionary['UNK']
      unk_count += 1
    data.append(index)
  count[0][1] = unk_count
  reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
  return data, count, dictionary, reverse_dictionary

可以看到,TF的word2vec實現(xiàn)里,詞頻越大,詞的類別編號也就越大。因此,在TF的word2vec里,負采樣的過程其實就是優(yōu)先采詞頻高的詞作為負樣本。

在提出負采樣的原始論文中, 包括word2vec的原始C++實現(xiàn)中。是按照熱門度的0.75次方采樣的,這個和TF的實現(xiàn)有所區(qū)別。但大概的意思差不多,就是越熱門,越有可能成為負樣本。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容