Candidate Sampling

本文首發(fā)于:http://xzyin.top

轉(zhuǎn)載請注明出處:http://xzyin.top/candidate-sampling/

在這篇文章中主要介紹一下Candidate sampling在模型訓練中的使用。

作為一個菜雞的推薦煉丹師,前段時間看YouTube DNN的煉丹手冊和雙塔模型(DSSM)的配藥指南。發(fā)現(xiàn)在關于計算優(yōu)化的部分YouTube DNN和DSSM都用了importance sampling進負樣本的選取。

從網(wǎng)絡結(jié)構(gòu)的搭建和數(shù)據(jù)的選取來看,YouTube DNN怎么看都像是一個廣義的word2vec??墒窃趙ord2vec模型的TensorFlow實現(xiàn)里面用的是NCE做計算性能的采樣優(yōu)化。

此外,對于YouTube DNN的負樣本選取知乎上也有廣泛的討論。例如:知乎石塔西在《負樣本為王:評Facebook的向量化召回算法》提出的hard模式和easy模式假設。

為了弄清楚不同網(wǎng)絡采樣的細節(jié)和采樣的方法,我翻閱大量偏方,并且稍微做一下梳理。

在介紹candidate sampling之前我們先了解一下問題的背景。

1. Softmax 和 Cross Entropy

在多分類問題中,模型訓練的目的是在訓練集上學習到一個函數(shù)

F(x,y)
,該模型對于測試集和驗證集上每一個輸入
x
,能夠準確地預測到對應的類別
y
。

在一個類別數(shù)為

K
的多分類問題中,模型的softmax層對每一個類別計算可能的概率:

P(y_j|x) = \frac{\mathrm{exp}(h^\mathsf{T}v'_j)}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) }

在softmax計算每一個輸入類別概率的基礎上,需要構(gòu)建損失函數(shù)

J
來評估模型訓練中學習得到
F(x,y)
的效果?;谧顦闼氐南敕?,我們希望預測的類別
y_{prediction}
和真實的類別
y_{label}
具有更為接近的分布。

在評估兩個分布的距離上,很自然地聯(lián)想到使用

KL
散度,那么對于損失函數(shù)的相關定義如下:

在模型訓練過程中,對應的輸入為:

x

樣本對應標簽的期望分布為
D(y|x)

模型
F(x,y)
預測出的類別分布為
P(y|x)

損失函數(shù):

J= H(D(y|x),P(y|x)) \\ = -D(y|x) \mathrm{ln}\frac{P(y|x)}{D(y|x)} \\ = D(y|x)\mathrm{ln}D(y|x) - D(y|x)\mathrm{ln}P(y|x)

在模型訓練過程中,對于相同數(shù)據(jù)集來說

D(y|x)\mathrm{ln}D(y|x)
可以看成是常數(shù)。

那么用來評估訓練結(jié)果的損失函數(shù)可以表示為:

\mathrm{min}\;(J) = \mathrm{min}\;(D(y|x)\mathrm{ln}D(y|x) - D(y|x)\mathrm{ln}P(y|x)) \\ = \mathrm{min}(\mathcal{K} - D(y|x)\mathrm{ln}P(y|x)) \\ \sim \mathrm{min}(- D(y|x)\mathrm{ln}P(y|x)) \\

其中,

\mathcal{K}
為常數(shù),那么損失函數(shù)的形式可以表示為期望分布
D(y|x)
和真實分布
P(y|x)
的交叉熵。

對所有

K
個類別求和,并且
D(y|x)
用標簽
y
的值表示得到損失函數(shù)形式如下:

J = -\sum_{i=1}^{K} y_i \mathrm{ln}(P(y_i|x)) \\

其中,當

i
對應的類別為正樣本時
y_i=1
,當
i
為負樣本時
y_i=0
,上述公式簡化為:

J = - \mathrm{ln}\;P(y_{pos}|x) \\ = - \mathrm{ln}\; \frac{\mathrm{exp}(h^\mathsf{T}v'_{pos})}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) } \\

其中

y_{pos}
表示正樣本類別。

2. candidate sampling

在上面的章節(jié)中,我們討論了模型訓練在多分類問題當中的損失函數(shù)。那么如果在分類數(shù)

K
非常多的情況下,對于每個樣本分類的預測都需要計算
K
個類別的概率。

顯然,在分類數(shù)較小的情況下softmax的計算量可以接受,但是當分類數(shù)目擴增到百萬甚至千萬量級的情況下會單個樣本的計算量過大。

假設,模型訓練的數(shù)據(jù)集中有1000萬條樣本數(shù)據(jù),其中分類的個數(shù)為百萬量級。每個分類概率的計算為0.01ms
那么所需要的計算時間:

T_{cost} = 10000000 \times 1000000 \times 0.01 \div 1000 \ 3600 \ 24 = 1157.d

顯然這個計算量級是沒有辦法被接受的。

那么基本思想就是在怎么樣不影響計算效果的前提下減小計算量。

針對這個問題目前具備兩種方法:

  • softmax-based approach: 基于樹結(jié)構(gòu)的分層softmax,減少損失函數(shù)計算過程中計算量。
  • sampling-based approach: 通過用采樣的方式,通過計算樣本的損失來代替全量的樣本計算。

這里我們主要介紹sampling-based approach的方法,也就是candidate sampling。

在做具體討論之前,我們先回到損失函數(shù)的計算公式,并做進一步的簡化:

J = - \mathrm{ln}\; \frac{\mathrm{exp}(h^\mathsf{T}v'_{pos})}{\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i) } \\ = - h^\mathsf{T}v'_{pos} + \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(h^\mathsf{T}v'_i)

在這個公式中用

\xi(w) = - h^\mathsf{T}v'_{pos}
,簡化為:

J = \xi (w_{pos}) + \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))

對損失函數(shù)求導并計算梯度

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \nabla_\theta \mathrm{ln}\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))

因為

log(x)
的梯度為
\frac{1}{x}

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \nabla_\theta\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))

然后我們把求導符號放到累加符內(nèi)得到:

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \sum^K_{i=1} \nabla_\theta \mathrm{exp}(-\xi(w_i))

并且有

\nabla_x \mathrm{exp}(x) = exp(x)
那么:

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \frac{1}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \sum^K_{i=1} \mathrm{exp}(-\xi(w_i)) \nabla_\theta (-\xi(w_i)

上面的公式可以重寫成:

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \sum^K_{i=1} \frac{\mathrm{exp}(-\xi(w_i))}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))} \nabla_\theta (-\xi(w_i)

其中

\frac{\mathrm{exp}(-\xi(w_i))}{\sum^K_{i=1} \mathrm{exp}(-\xi(w_i))}
就是輸入上下文
c
在類別
i
上的的概率
P(w_i | c)

最終要計算的梯度形式如下:

\nabla_\theta J = \nabla_\theta \xi (w_{pos}) + \sum_{i=1}^K P(w_i|c) \nabla_\theta (-\xi(w_i) \\ = \nabla_\theta \xi (w_{pos}) - \sum_{i=1}^K P(w_i|c) \nabla_\theta \xi(w_i)

根據(jù)最終的公式,我們可以將梯度的計算分為兩部分:

  • \nabla_\theta \xi (w_{pos})
    : 是參數(shù)關于正樣本
    y_{pos}
    的梯度,可以理解為對目標詞的正面優(yōu)化。

  • -\sum_{i=1}^K P(w_i\|c)\nabla_\theta \xi(w_i)
    : 是所有樣本概率對應梯度的累加和,可以理解為對其他詞匯的負向優(yōu)化。

在基于采樣的優(yōu)化當中,我們不需要計算所有類別的累加,只需要通過采樣求到

\nabla_\theta \xi(w_i)
在分布
P(w_i|c)
的期望即可。

那么:

\sum_{i=1}^K P(w_i|c)\nabla_\theta \xi(w_i) = \mathbb{E}_{w_i\sim P} [\nabla_\theta \xi(w_i)] \\

那么接下來的問題就變成了如何準確的計算梯度在概率分布

P(w_i)
上的期望:

\mathbb{E}_{w_i\sim P} [\nabla_\theta \xi(w_i)]

3. 常見的candidate sampling方法

在了解了candidate sampling方法的基本思想之后,我們怎么樣計算期望

\mathbb{E_{w_i\sim P(w_i)}\nabla_\theta \xi(w_i)}

成為一個值得考慮的問題。

3.1 Importance Sampling

對于任何概率分布我們計算期望

\mathbb{E}
的時候,可以采用蒙特卡洛方法,根據(jù)分布隨機采樣出一系列樣本,然后計算樣本的平均值。

對于上述的例子,如果我們知道模型在不同類別的概率分布

P(w_i)
,在計算期望的時候可以直接采樣出
m
個類別
w_1,...,w_m

并且計算期望:

\mathbb{E}_{w_i \sim P}[\nabla_\theta\xi(w_i)] \approx \frac{1}{m}\sum_i^m \nabla_\theta\xi(w_i)

但是為了從分布

P
中采樣樣本,我們首先需要計算分布
P
。可是candidate sampling的目的就是為了避免計算分布
P
。
為了解決這個問題,能夠使用的基本方法是重要性采樣:

重要性采樣(importance sampling)算法

假設我們需要計算概率密度函數(shù)

h(x)
\pi(x)
上的期望

\mu = \mathbb{E}_\pi{h(x)} = \int h(x)\pi(x)

那么重要性采樣算法對應的形式如下:

(a) 首先,從分布

g(\cdot)
中隨機采樣出
m
個樣本
\mathrm{x}_1,...,\mathrm{x}_m

(b) 計算重要性權重:

r(\mathrm{x}_i) = \frac{\pi(\mathrm{x}_i)}{g(\mathrm{x}_i)}, for\;\;j=1,...,m

(c) 近似期望

\hat \mu

\hat u=\frac{r_1 h(\mathrm{x}_1)+...+r_m h(\mathrm{x}_m)}{r(\mathrm{x}_1) + ... + r(\mathrm{x}_m)}

為了使得估計的時候誤差更小,我們需要盡可能地使得

g(\cdot)
接近原來的
\pi(\mathrm{x})
。

這個時候上述公式可以描述為:

\hat \mu = \frac{1}{m} \{r(\mathrm{x}_1) h(\mathrm{x}_1) + ... + r(\mathrm{x}_m) h(\mathrm{x}_m)\}

根據(jù)上述描述,我們先預設一個分布

Q(w)
,為了使得
Q(w)
盡可能接近
P(w)
,一般可以采樣一元分布。

對應的重要性權重

r(w) = \frac{\mathrm{exp}(-\xi(w))}{Q(w)}
,那么對應的期望計算公式如下:

\mathbb{E}_{w_i \sim P} \approx \frac{r(w_1) \nabla_\theta \xi(w_1) + ... + r(w_m) \nabla_\theta \xi(w_m)}{r(w_i) +...+r(w_m)} \\ = \frac{\sum_{i=1}^m r(w_i) \nabla_\theta \xi(w_i)}{\sum_{i=1}^m r(w_i)}

R = \sum_{i=1}^m r(w_i)
得到

\mathbb{E}_{w_i \sim P} \approx \frac{1}{R} \sum_{i=1}^m r(w_i) \nabla_\theta \xi(w_i)

3.2 Noise Contrastive Estimation

在上面介紹完成Importance Sampling之后,我們來看一下Noise Contrastive Estimation(NCE)。拋開上面通過采樣的思想,利用importance sampling近似計算多分類問題softmax損失的方法。

在NCE中,完全推翻上述方法并從試圖從另外一個角度來解決多分類問題loss計算的問題——我們能否找到一個損失函數(shù)用于替代原來的損失計算,從而避免softmax中歸一化因子的計算。

NCE的基本思想是將多分類問題轉(zhuǎn)換成為二分類問題,從噪音分布中采樣,減少優(yōu)化過程的計算復雜度。

在采樣NCE方式計算loss的過程中,我們引入噪音分布

Q(w)
。這個噪音分布可以跟語境有關,也可以跟語境無關。在噪音分布和語境無關的情況下,我們設置噪音分布的強度是真實數(shù)據(jù)分布的
m
倍。

那么對于訓練數(shù)據(jù)

(c,w)
可以得到真實分布和噪音分布的概率:

P(y=1|w,c) = \frac{P_{train}(w|c)}{P_{train}(w|c) + mQ(w|c)}\\ P(y=0|w,c) = \frac{mQ(w|c)}{P_{train}(w|c) + mQ(w|c)}

得到

P(w|c) = P_{train}(w|c) + mQ(w|c)

在原來的推導中:

P(w|c) = \frac{\mathrm{exp}(h^\mathrm{T} v'_{w})}{\sum_{i=1}^K \mathrm{exp}(h^\mathrm{T} v'_{w_i})}

在NCE中為了避免對分母部分歸一化因子的計算,將歸一化因子表示為一個學習的參數(shù)

Z(c)

Z(c) = \sum_{i=1}^K \mathrm{exp}(h^\mathsf{T} v_{w'_i})

這個時候簡化為:

P(w|c) = \mathrm{exp}(h^\mathsf{T} v'_{w})

那么對于這個二分類問題計算Logistic regression損失:

J = [ln \frac{\mathrm{exp}(h^\mathsf{T} v'_{w_i})}{\mathrm{exp}(h^\mathsf{T} v'_{w_i}) + mQ(w_i)}] + \sum_{j=1}^m [ln(1-ln \frac{\mathrm{exp}(h^\mathsf{T} v'_{w_{i,j}})}{\mathrm{exp}(h^\mathsf{T} v'_{w_{i,j}}) + mQ(w_{i,j})})]

在上述公式中,當

m\rightarrow \infty
, 上述公式和softmax的損失函數(shù)相似。

從NCE采樣方法中可知:

  • 基于softmax的多分類問題的損失函數(shù)可以表示成為logistic regression二分類的形式。
  • NCE方法中,在梯度更新中放棄了對負樣本參數(shù)的更新。

4. Tensorflow中candidate sampling的實現(xiàn)

理論很豐滿,落地很骨感。

在了解完candidate sampling中的Importance sampling和Noise Contrastive Estimation的原理之后如果要工程落地還是需要依賴可用的計算框架。在TensorFlow中就實現(xiàn)了這兩個方法對應可以調(diào)用的API分別是:

  • importance sampling: tf.nn.sampled_softmax_loss()
  • Noise Contrastive Estimation: tf.nn.nce_loss()

4.1 tf.nn.sampled_softmax_loss()

sampled_softmax_loss()中包含了兩部分內(nèi)容。

  1. _compute_sampled_logits()
  2. softmax_cross_entropy_with_logits_v2()

_compute_sampled_logits() 主要進行采樣并計算logit。

softmax_cross_entropy_with_logits_v2() 主要計算softmax的交叉熵損失。
接下來我們主要看一下_compute_sampled_logits()的源碼。

4.2 tf.nn.nce_loss()

nce_loss()中包含了兩部分內(nèi)容。

  1. _compute_sampled_logits()
  2. sigmoid_cross_entropy_with_logits()

_compute_sampled_logits() 主要進行采樣并計算logit。

sigmoid_cross_entropy_with_logits() 主要計算sigmoid的交叉熵損失。

接下來我們主要看一下_compute_sampled_logits()的源碼。

def _compute_sampled_logits(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):

  if isinstance(weights, variables.PartitionedVariable):
    weights = list(weights)
  if not isinstance(weights, list):
    weights = [weights]

  with ops.name_scope(name, "compute_sampled_logits",
                      weights + [biases, inputs, labels]):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])

    # Sample the negative labels.
    #   sampled shape: [num_sampled] tensor
    #   true_expected_count shape = [batch_size, 1] tensor
    #   sampled_expected_count shape = [num_sampled] tensor
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes,
          seed=seed)
    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
    # pylint: disable=unpacking-non-sequence
    sampled, true_expected_count, sampled_expected_count = (
        array_ops.stop_gradient(s) for s in sampled_values)
    # pylint: enable=unpacking-non-sequence
    sampled = math_ops.cast(sampled, dtypes.int64)

    # labels_flat is a [batch_size * num_true] tensor
    # sampled is a [num_sampled] int tensor
    all_ids = array_ops.concat([labels_flat, sampled], 0)

    # Retrieve the true weights and the logits of the sampled weights.

    # weights shape is [num_classes, dim]
    all_w = embedding_ops.embedding_lookup(
        weights, all_ids, partition_strategy=partition_strategy)
    if all_w.dtype != inputs.dtype:
      all_w = math_ops.cast(all_w, inputs.dtype)

    # true_w shape is [batch_size * num_true, dim]
    true_w = array_ops.slice(all_w, [0, 0],
                             array_ops.stack(
                                 [array_ops.shape(labels_flat)[0], -1]))

    sampled_w = array_ops.slice(
        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    # inputs has shape [batch_size, dim]
    # sampled_w has shape [num_sampled, dim]
    # Apply X*W', which yields [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)

    # Retrieve the true and sampled biases, compute the true logits, and
    # add the biases to the true and sampled logits.
    all_b = embedding_ops.embedding_lookup(
        biases, all_ids, partition_strategy=partition_strategy)
    if all_b.dtype != inputs.dtype:
      all_b = math_ops.cast(all_b, inputs.dtype)
    # true_b is a [batch_size * num_true] tensor
    # sampled_b is a [num_sampled] float tensor
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
    row_wise_dots = math_ops.multiply(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat([[-1], dim], 0))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b
    sampled_logits += sampled_b

    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(
          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
                                        "sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          [array_ops.shape(labels)[:1],
           array_ops.expand_dims(num_sampled, 0)], 0)
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += gen_sparse_ops.sparse_to_dense(
          sparse_indices,
          sampled_logits_shape,
          acc_weights,
          default_value=0.0,
          validate_indices=False)

    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = array_ops.concat([true_logits, sampled_logits], 1)

    # true_logits is a float tensor, ones_like(true_logits) is a float
    # tensor of ones. We then divide by num_true to ensure the per-example
    # labels sum to 1.0, i.e. form a proper probability distribution.
    out_labels = array_ops.concat([
        array_ops.ones_like(true_logits) / num_true,
        array_ops.zeros_like(sampled_logits)
    ], 1)

    return out_logits, out_labels

參考資料

[1] 從最優(yōu)化的角度看待Softmax損失函數(shù)

[2] On word embeddings - Part 2: Approximating the Softmax

[3] 重要性采樣

[4] Noise Contrastive Estimation

?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

友情鏈接更多精彩內(nèi)容