Skip-gram Source Code Implementation and Walkthrough

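Skip-gram is an algorithm for training word vectors. It represents each word as a vector in such a way that the semantic relationships between words are preserved among the vectors. It is a neural-network method that predicts the surrounding words from the center word.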
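To make "predicting the surrounding words from the center word" concrete, the sketch below builds (center, context) index pairs from a tokenized sentence. The helper name `make_pairs` and its `window` parameter are illustrative assumptions rather than part of the implementation that follows, and it uses only the Python standard library.

```python
def make_pairs(tokens, word2idx, window=2):
    """Return (center, context) index pairs within `window` words on each side."""
    pairs = []
    for i, center in enumerate(tokens):
        lo = max(0, i - window)
        hi = min(len(tokens), i + window + 1)
        for j in range(lo, hi):
            if j != i:  # a word is not its own context
                pairs.append((word2idx[center], word2idx[tokens[j]]))
    return pairs


tokens = "the quick brown fox jumps over the lazy dog".split()

# Give each distinct word one index, in order of first appearance
# (an assumed vocabulary scheme; repeated words share an index).
word2idx = {}
for w in tokens:
    word2idx.setdefault(w, len(word2idx))

pairs = make_pairs(tokens, word2idx, window=2)
print(len(word2idx), len(pairs))
print(pairs[:4])
```

Generating the pairs from the vocabulary, rather than from word positions, keeps every index inside the embedding table and lets repeated words such as "the" share a single vector.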

import paddle
import paddle.nn as nn


class SkipGram(paddle.nn.Layer):

    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()

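        # Two embedding layers: one for the input (center) word, one for the output (context) word
        self.embedding_in = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_out = nn.Embedding(vocab_size, embedding_dim)
        # Optional fully connected layer, applied to both embeddings in forward()
        self.linear = nn.Linear(embedding_dim, vocab_size)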

    def forward(self, input_word, context_word):
        # Look up the embeddings of the input word and the context word
        input_embed = self.embedding_in(input_word)
        context_embed = self.embedding_out(context_word)

        # Optional: add a fully connected layer.
        # Adding a fully connected layer to Skip-gram is meant to increase the model's capacity by
        # applying a further learned transformation to the embeddings. Note that a single linear
        # layer without an activation is still a linear map, so extra non-linearity requires an
        # activation function after it. One option is to apply the layer to the input embedding
        # and then a softmax over the vocabulary to obtain predictions. The extra layer also
        # increases model complexity and training difficulty, so it needs more training data and
        # compute to train and tune.
        input_embed = self.linear(input_embed)
        context_embed = self.linear(context_embed)

        # Score the pair with an inner product
        # Method 1
        score = paddle.mm(input_embed, context_embed.t())

        # Method 2
        # score = paddle.matmul(input_embed, context_embed, transpose_y=True)
        # score = paddle.sum(score, axis=-1)
        return score


num_epochs = 10
str1 = "the quick brown fox jumps over the lazy dog"
# Hand-written (center, context) index pairs taken from a window of up to 2 words on each side.
# Indices follow word2idx below, so the second "the" in the sentence reuses index 0.
training_data = [
    (0, 1), (0, 2), (1, 0), (1, 2), (1, 3),
    (2, 0), (2, 1), (2, 3), (3, 1), (3, 2),
    (3, 4), (4, 3), (4, 5), (5, 4), (5, 0),
    (0, 5), (0, 6), (6, 0), (6, 7), (7, 6)
]

word2idx = {"the": 0, "quick": 1, "brown": 2, "fox": 3, "jumps": 4, "over": 5, "lazy": 6, "dog": 7}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(word2idx)

embedding_dim = 64

model = SkipGram(vocab_size, embedding_dim)
criterion = nn.BCEWithLogitsLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

for epoch in range(num_epochs):
    total_loss = 0

    for input_word, context_word in training_data:
        # Convert the indices to tensors
        input_word = paddle.to_tensor([input_word])
        context_word = paddle.to_tensor([context_word])

        # Put the model in training mode
        model.train()
        # Forward pass
        output = model(input_word, context_word)

        ones_label = paddle.ones_like(output)
        # Compute the loss (every pair is a positive example, so the target is all ones)
        # loss = nn.functional.binary_cross_entropy_with_logits(output, ones_label)
        loss = criterion(output, ones_label)
        total_loss += loss.item()

        # Backward pass and parameter update
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
    # Print the average loss for each epoch
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, total_loss / len(training_data)))

# Extract the word vectors and save them to a file
embeddings = model.embedding_in.weight.numpy()
with open('embeddings.txt', 'w') as f:
    for i, word in idx2word.items():
        embedding = ' '.join(map(str, embeddings[i]))
        f.write('{} {}\n'.format(word, embedding))




The loss values from one run are as follows:

Epoch [1/10], Loss: 0.6915
Epoch [2/10], Loss: 0.6712
Epoch [3/10], Loss: 0.6448
Epoch [4/10], Loss: 0.6096
Epoch [5/10], Loss: 0.5639
Epoch [6/10], Loss: 0.5069
Epoch [7/10], Loss: 0.4408
Epoch [8/10], Loss: 0.3709
Epoch [9/10], Loss: 0.3037
Epoch [10/10], Loss: 0.2444
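As a sanity check on the first value: when the target label is always 1, binary cross-entropy on a logit reduces to log(1 + exp(-score)), so a score of 0 gives ln 2 ≈ 0.6931. The first-epoch average of 0.6915 is consistent with scores that start close to zero. The sketch below evaluates that formula in plain Python; `bce_with_logits_pos` is a hypothetical helper written for illustration, not a Paddle function, and this is not a re-run of the training.

```python
import math


def bce_with_logits_pos(score):
    """Binary cross-entropy on a logit when the target label is 1: log(1 + exp(-score))."""
    # Numerically stable form of softplus(-score)
    return max(-score, 0.0) + math.log1p(math.exp(-abs(score)))


for s in (0.0, 1.0, 2.0):
    print(s, round(bce_with_logits_pos(s), 4))
```

Because every training pair here is labeled 1, the loss falls whenever the scores grow, which matches the steady decrease in the log above. A falling loss in this setup therefore says little on its own about the quality of the vectors.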

embeddings.txt contains the embedding of each word, in the following format:

the 0.23742944 -0.09895072 -0.11146876 0.2605418 0.024332121 0.0910439 0.12424937 -0.013771858 0.116671495 0.0015973783 -0.10863184 -0.10972429 0.07718096 -0.0033314745 0.2224411 -0.22004403 0.1281483 -0.12591755 0.14134666 -0.034466334 0.24389501 -0.07925096 0.10862582 -0.22061633 0.18360858 -0.17183 0.117620915 -0.23871568 -0.21196923 -0.014604413 0.040446073 0.17018412 -0.24544948 0.10585982 0.05756965 0.058975074 -0.2550219 0.2217722 -0.11203941 0.21279284 0.26438716 -0.17886016 -0.17222002 0.036797255 0.00933418 0.07391108 -0.20267555 -0.21875764 -0.30134645 0.25732276 -0.12506652 -0.060186304 -0.058356136 0.1225115 0.043293785 0.040848006 0.008795115 0.15603863 -0.23737802 -0.045909956 0.14689028 -0.01215158 0.2031173 0.101658516
quick 0.22985205 -0.28023568 0.17331894 -0.18404774 0.19435923 0.054511975 -0.12446486 -0.20461103 -0.20024063 0.074314184 -0.19651282 -0.15052138 -0.25369388 -0.0021391092 -0.2760222 0.10754039 0.11228328 -0.019922553 0.24608843 -0.2582981 -0.11957916 -0.18738061 0.018962713 -0.096384935 -0.26450405 0.066627055 -0.0071602613 -0.077308446 0.26354805 0.07547034 0.058478173 -0.19880083 -0.29015306 0.034329493 0.2207786 -0.11239037 0.049043965 -0.021390196 -0.004287906 -0.28705558 -0.1730856 -0.27100953 0.1121744 -0.25906146 -0.074053064 0.10330311 0.04657338 -0.119830996 -0.17361426 0.17114878 0.1927943 -0.2083592 -0.1774211 -0.2773358 -0.114716105 -0.011761455 -0.1675885 0.1555276 -0.15725754 -0.00861447 -0.27093074 -0.24180736 -0.18109317 0.27589953
brown 0.15149793 -0.19986486 0.2548086 0.020637682 -0.11013863 0.024790183 -0.02504396 0.037789762 -0.020729668 -0.23532745 0.28116202 -0.04157986 -0.29003182 0.29875976 0.16469309 0.23130749 0.17639601 -0.23869719 -0.13300861 0.27599373 0.02700885 0.05513569 0.26320535 -0.22142021 0.013878512 0.10758007 0.22711909 -0.18499781 0.070877045 0.079043075 -0.24289952 -0.2636248 0.0006990259 0.18134123 -0.023455022 -0.034577943 -0.25355765 0.29205313 0.23203316 0.04200985 -0.039580178 -0.21799651 0.20781282 0.083057314 -0.22915262 0.21067782 -0.21856064 0.16073515 0.10993917 -0.14174365 0.097185716 -0.17790347 0.18403171 0.012047063 0.20417404 0.05510201 0.135194 -0.0029973947 0.007548025 0.04317737 0.12034502 0.05921867 0.030197665 0.061334394
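Once the file is written, the vectors can be read back and compared. The sketch below assumes the `word v1 v2 ...` line format shown above. `load_embeddings` and `cosine` are hypothetical helper names, and the small demo file holds made-up 3-dimensional vectors used only to exercise the parser, not output from the model.

```python
import math
import os
import tempfile


def load_embeddings(path):
    """Parse lines of the form `word v1 v2 ...` into a dict of word -> list of floats."""
    vectors = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


# Made-up vectors, written to a temporary file for the demo.
demo = "the 1.0 0.0 0.0\nquick 1.0 1.0 0.0\nbrown 0.0 0.0 2.0\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as tmp:
    tmp.write(demo)
    demo_path = tmp.name

vecs = load_embeddings(demo_path)
os.remove(demo_path)

print(round(cosine(vecs["the"], vecs["quick"]), 4))
print(cosine(vecs["the"], vecs["brown"]))
```

With the real file, `load_embeddings('embeddings.txt')` would return one 64-dimensional vector per word. On a corpus of a single sentence, the resulting similarities should be treated as a check that the pipeline works rather than as meaningful semantics.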