本文首發(fā)于:http://xzyin.top
轉(zhuǎn)載請注明出處:http://xzyin.top/candidate-sampling/
在這篇文章中主要介紹一下Candidate sampling在模型訓練中的使用。
作為一個菜雞的推薦煉丹師,前段時間看YouTube DNN的煉丹手冊和雙塔模型(DSSM)的配藥指南。發(fā)現(xiàn)在關于計算優(yōu)化的部分YouTube DNN和DSSM都用了importance sampling進負樣本的選取。
從網(wǎng)絡結(jié)構(gòu)的搭建和數(shù)據(jù)的選取來看,YouTube DNN怎么看都像是一個廣義的word2vec??墒窃趙ord2vec模型的TensorFlow實現(xiàn)里面用的是NCE做計算性能的采樣優(yōu)化。
此外,對于YouTube DNN的負樣本選取知乎上也有廣泛的討論。例如:知乎石塔西在《負樣本為王:評Facebook的向量化召回算法》提出的hard模式和easy模式假設。
為了弄清楚不同網(wǎng)絡采樣的細節(jié)和采樣的方法,我翻閱大量偏方,并且稍微做一下梳理。
在介紹candidate sampling之前我們先了解一下問題的背景。
1. Softmax 和 Cross Entropy
在多分類問題中,模型訓練的目的是在訓練集上學習到一個函數(shù)
在一個類別數(shù)為
在softmax計算每一個輸入類別概率的基礎上,需要構(gòu)建損失函數(shù)
在評估兩個分布的距離上,很自然地聯(lián)想到使用
在模型訓練過程中,對應的輸入為:
樣本對應標簽的期望分布為
模型
損失函數(shù):
在模型訓練過程中,對于相同數(shù)據(jù)集來說
那么用來評估訓練結(jié)果的損失函數(shù)可以表示為:
其中,
對所有
其中,當
其中
2. candidate sampling
在上面的章節(jié)中,我們討論了模型訓練在多分類問題當中的損失函數(shù)。那么如果在分類數(shù)
顯然,在分類數(shù)較小的情況下softmax的計算量可以接受,但是當分類數(shù)目擴增到百萬甚至千萬量級的情況下會單個樣本的計算量過大。
假設,模型訓練的數(shù)據(jù)集中有1000萬條樣本數(shù)據(jù),其中分類的個數(shù)為百萬量級。每個分類概率的計算為0.01ms
那么所需要的計算時間:![]()
顯然這個計算量級是沒有辦法被接受的。
那么基本思想就是在怎么樣不影響計算效果的前提下減小計算量。
針對這個問題目前具備兩種方法:
- softmax-based approach: 基于樹結(jié)構(gòu)的分層softmax,減少損失函數(shù)計算過程中計算量。
- sampling-based approach: 通過用采樣的方式,通過計算樣本的損失來代替全量的樣本計算。
這里我們主要介紹sampling-based approach的方法,也就是candidate sampling。
在做具體討論之前,我們先回到損失函數(shù)的計算公式,并做進一步的簡化:
在這個公式中用
對損失函數(shù)求導并計算梯度
因為
然后我們把求導符號放到累加符內(nèi)得到:
并且有
上面的公式可以重寫成:
其中
最終要計算的梯度形式如下:
根據(jù)最終的公式,我們可以將梯度的計算分為兩部分:
- : 是參數(shù)關于正樣本的梯度,可以理解為對目標詞的正面優(yōu)化。
- : 是所有樣本概率對應梯度的累加和,可以理解為對其他詞匯的負向優(yōu)化。
在基于采樣的優(yōu)化當中,我們不需要計算所有類別的累加,只需要通過采樣求到
那么:
那么接下來的問題就變成了如何準確的計算梯度在概率分布
3. 常見的candidate sampling方法
在了解了candidate sampling方法的基本思想之后,我們怎么樣計算期望
成為一個值得考慮的問題。
3.1 Importance Sampling
對于任何概率分布我們計算期望
對于上述的例子,如果我們知道模型在不同類別的概率分布
并且計算期望:
但是為了從分布
為了解決這個問題,能夠使用的基本方法是重要性采樣:
重要性采樣(importance sampling)算法
假設我們需要計算概率密度函數(shù)
在
上的期望![]()
![]()
那么重要性采樣算法對應的形式如下:
(a) 首先,從分布
中隨機采樣出
個樣本![]()
![]()
(b) 計算重要性權重:
![]()
(c) 近似期望
![]()
![]()
為了使得估計的時候誤差更小,我們需要盡可能地使得
接近原來的
。![]()
這個時候上述公式可以描述為:
![]()
根據(jù)上述描述,我們先預設一個分布
對應的重要性權重
令
3.2 Noise Contrastive Estimation
在上面介紹完成Importance Sampling之后,我們來看一下Noise Contrastive Estimation(NCE)。拋開上面通過采樣的思想,利用importance sampling近似計算多分類問題softmax損失的方法。
在NCE中,完全推翻上述方法并從試圖從另外一個角度來解決多分類問題loss計算的問題——我們能否找到一個損失函數(shù)用于替代原來的損失計算,從而避免softmax中歸一化因子的計算。
NCE的基本思想是將多分類問題轉(zhuǎn)換成為二分類問題,從噪音分布中采樣,減少優(yōu)化過程的計算復雜度。
在采樣NCE方式計算loss的過程中,我們引入噪音分布
那么對于訓練數(shù)據(jù)
得到
在原來的推導中:
在NCE中為了避免對分母部分歸一化因子的計算,將歸一化因子表示為一個學習的參數(shù)
這個時候簡化為:
那么對于這個二分類問題計算Logistic regression損失:
在上述公式中,當
從NCE采樣方法中可知:
- 基于softmax的多分類問題的損失函數(shù)可以表示成為logistic regression二分類的形式。
- NCE方法中,在梯度更新中放棄了對負樣本參數(shù)的更新。
4. Tensorflow中candidate sampling的實現(xiàn)
理論很豐滿,落地很骨感。
在了解完candidate sampling中的Importance sampling和Noise Contrastive Estimation的原理之后如果要工程落地還是需要依賴可用的計算框架。在TensorFlow中就實現(xiàn)了這兩個方法對應可以調(diào)用的API分別是:
- importance sampling: tf.nn.sampled_softmax_loss()
- Noise Contrastive Estimation: tf.nn.nce_loss()
4.1 tf.nn.sampled_softmax_loss()
在sampled_softmax_loss()中包含了兩部分內(nèi)容。
- _compute_sampled_logits()
- softmax_cross_entropy_with_logits_v2()
_compute_sampled_logits() 主要進行采樣并計算logit。
softmax_cross_entropy_with_logits_v2() 主要計算softmax的交叉熵損失。
接下來我們主要看一下_compute_sampled_logits()的源碼。
4.2 tf.nn.nce_loss()
在nce_loss()中包含了兩部分內(nèi)容。
- _compute_sampled_logits()
- sigmoid_cross_entropy_with_logits()
_compute_sampled_logits() 主要進行采樣并計算logit。
sigmoid_cross_entropy_with_logits() 主要計算sigmoid的交叉熵損失。
接下來我們主要看一下_compute_sampled_logits()的源碼。
def _compute_sampled_logits(weights,
biases,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
sampled_values=None,
subtract_log_q=True,
remove_accidental_hits=False,
partition_strategy="mod",
name=None,
seed=None):
if isinstance(weights, variables.PartitionedVariable):
weights = list(weights)
if not isinstance(weights, list):
weights = [weights]
with ops.name_scope(name, "compute_sampled_logits",
weights + [biases, inputs, labels]):
if labels.dtype != dtypes.int64:
labels = math_ops.cast(labels, dtypes.int64)
labels_flat = array_ops.reshape(labels, [-1])
# Sample the negative labels.
# sampled shape: [num_sampled] tensor
# true_expected_count shape = [batch_size, 1] tensor
# sampled_expected_count shape = [num_sampled] tensor
if sampled_values is None:
sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=labels,
num_true=num_true,
num_sampled=num_sampled,
unique=True,
range_max=num_classes,
seed=seed)
# NOTE: pylint cannot tell that 'sampled_values' is a sequence
# pylint: disable=unpacking-non-sequence
sampled, true_expected_count, sampled_expected_count = (
array_ops.stop_gradient(s) for s in sampled_values)
# pylint: enable=unpacking-non-sequence
sampled = math_ops.cast(sampled, dtypes.int64)
# labels_flat is a [batch_size * num_true] tensor
# sampled is a [num_sampled] int tensor
all_ids = array_ops.concat([labels_flat, sampled], 0)
# Retrieve the true weights and the logits of the sampled weights.
# weights shape is [num_classes, dim]
all_w = embedding_ops.embedding_lookup(
weights, all_ids, partition_strategy=partition_strategy)
if all_w.dtype != inputs.dtype:
all_w = math_ops.cast(all_w, inputs.dtype)
# true_w shape is [batch_size * num_true, dim]
true_w = array_ops.slice(all_w, [0, 0],
array_ops.stack(
[array_ops.shape(labels_flat)[0], -1]))
sampled_w = array_ops.slice(
all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
# inputs has shape [batch_size, dim]
# sampled_w has shape [num_sampled, dim]
# Apply X*W', which yields [batch_size, num_sampled]
sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
# Retrieve the true and sampled biases, compute the true logits, and
# add the biases to the true and sampled logits.
all_b = embedding_ops.embedding_lookup(
biases, all_ids, partition_strategy=partition_strategy)
if all_b.dtype != inputs.dtype:
all_b = math_ops.cast(all_b, inputs.dtype)
# true_b is a [batch_size * num_true] tensor
# sampled_b is a [num_sampled] float tensor
true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
# inputs shape is [batch_size, dim]
# true_w shape is [batch_size * num_true, dim]
# row_wise_dots is [batch_size, num_true, dim]
dim = array_ops.shape(true_w)[1:2]
new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
row_wise_dots = math_ops.multiply(
array_ops.expand_dims(inputs, 1),
array_ops.reshape(true_w, new_true_w_shape))
# We want the row-wise dot plus biases which yields a
# [batch_size, num_true] tensor of true_logits.
dots_as_matrix = array_ops.reshape(row_wise_dots,
array_ops.concat([[-1], dim], 0))
true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
true_b = array_ops.reshape(true_b, [-1, num_true])
true_logits += true_b
sampled_logits += sampled_b
if remove_accidental_hits:
acc_hits = candidate_sampling_ops.compute_accidental_hits(
labels, sampled, num_true=num_true)
acc_indices, acc_ids, acc_weights = acc_hits
# This is how SparseToDense expects the indices.
acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
acc_ids_2d_int32 = array_ops.reshape(
math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
"sparse_indices")
# Create sampled_logits_shape = [batch_size, num_sampled]
sampled_logits_shape = array_ops.concat(
[array_ops.shape(labels)[:1],
array_ops.expand_dims(num_sampled, 0)], 0)
if sampled_logits.dtype != acc_weights.dtype:
acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
sampled_logits += gen_sparse_ops.sparse_to_dense(
sparse_indices,
sampled_logits_shape,
acc_weights,
default_value=0.0,
validate_indices=False)
if subtract_log_q:
# Subtract log of Q(l), prior probability that l appears in sampled.
true_logits -= math_ops.log(true_expected_count)
sampled_logits -= math_ops.log(sampled_expected_count)
# Construct output logits and labels. The true labels/logits start at col 0.
out_logits = array_ops.concat([true_logits, sampled_logits], 1)
# true_logits is a float tensor, ones_like(true_logits) is a float
# tensor of ones. We then divide by num_true to ensure the per-example
# labels sum to 1.0, i.e. form a proper probability distribution.
out_labels = array_ops.concat([
array_ops.ones_like(true_logits) / num_true,
array_ops.zeros_like(sampled_logits)
], 1)
return out_logits, out_labels
參考資料
[1] 從最優(yōu)化的角度看待Softmax損失函數(shù)