keras實(shí)現(xiàn)Attention機(jī)制

attention層的定義:(思路參考https://github.com/philipperemy/keras-attention-mechanism

# Attention GRU network       
class AttLayer(Layer):
    def __init__(self, **kwargs):
        self.init = initializations.get('normal')
        #self.input_spec = [InputSpec(ndim=3)]
        super(AttLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape)==3
        #self.W = self.init((input_shape[-1],1))
        self.W = self.init((input_shape[-1],))
        #self.input_spec = [InputSpec(shape=input_shape)]
        self.trainable_weights = [self.W]
        super(AttLayer, self).build(input_shape)  # be sure you call this somewhere!

    def call(self, x, mask=None):
        eij = K.tanh(K.dot(x, self.W))
        
        ai = K.exp(eij)
        weights = ai/K.sum(ai, axis=1).dimshuffle(0,'x')
        
        weighted_input = x*weights.dimshuffle(0,1,'x')
        return weighted_input.sum(axis=1)

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], input_shape[-1])

具體的用法:

input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(input)
l_lstm = Bidirectional(LSTM(100, return_sequences=True))(embedded_sequences)
l_att = AttLayer()(l_lstm)
preds = Dense(2, activation='softmax')(l_att)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy',
             optimizer='rmsprop',
             metrics=['acc'])

print("model fitting - attention GRU network")
model.summary()
model.fit(x_train, y_train, validation_data=(x_val, y_val),
         nb_epoch=10, batch_size=50)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • afinalAfinal是一個(gè)android的ioc,orm框架 https://github.com/yangf...
    wgl0419閱讀 6,598評(píng)論 1 9
  • afinalAfinal是一個(gè)android的ioc,orm框架 https://github.com/yangf...
    passiontim閱讀 15,889評(píng)論 2 45
  • Android 自定義View的各種姿勢(shì)1 Activity的顯示之ViewRootImpl詳解 Activity...
    passiontim閱讀 179,171評(píng)論 25 708
  • 大千世界里, 總會(huì)有人看不慣你, 茫茫人海中, 總會(huì)有自己的晴空。 不管這個(gè)社會(huì)變化多快, 不管周?chē)娜诵挠卸鄰?fù)雜...
    飛揚(yáng)的柳絮閱讀 382評(píng)論 0 0
  • 你愛(ài)我嗎? 愛(ài)。 為什么愛(ài)? 沒(méi)有理由。 我脾氣大,愛(ài)耍小性子,經(jīng)常沖你發(fā)脾氣,也不會(huì)做家務(wù),長(zhǎng)得也不好看…… 好...
    寶菇?jīng)鱿壬?/span>閱讀 432評(píng)論 0 0

友情鏈接更多精彩內(nèi)容