離散特征和連續(xù)特征同時(shí)存在,同時(shí)解決回歸和分類的問(wèn)題

有些同學(xué)對(duì)于mxnet的自定義Iter不是很熟悉,對(duì)多輸出也不熟悉,因此我用一個(gè)比較復(fù)雜的例子來(lái)說(shuō)明這個(gè)問(wèn)題:

1. 特征中有連續(xù)特征和離散特征
2. 同時(shí)要解決回歸問(wèn)題和分類問(wèn)題

本著End-to-End的精神,我們不做特征工程,當(dāng)然也就不能做離散化。于是,連續(xù)特征可以直接作為輸入,而離散特征則通過(guò)Embeding的方式輸入。如果要同時(shí)解決回歸和分類問(wèn)題,我們就需要兩個(gè)Loss層。

我們虛構(gòu)一個(gè)簡(jiǎn)單的二手車價(jià)格預(yù)估的問(wèn)題。我們假設(shè)一輛車的價(jià)格只取決于兩個(gè)因素,一個(gè)是車的品牌,一個(gè)是車的里程。不同品牌的車有不同的出廠價(jià)格,而車的行駛里程越長(zhǎng),價(jià)格就會(huì)越低。因此我們可以基于這個(gè)假設(shè),用如下的代碼構(gòu)造一個(gè)數(shù)據(jù)集:

#我們虛構(gòu)了201個(gè)不同的品牌,并給每個(gè)品牌設(shè)置一個(gè)出場(chǎng)價(jià)格
series = [1 + i for i in range(100)] + [101 - i for i in range(100)]

for i in range(10000):
    k = random.randint(0, 199)
    #越貴的品牌,我們認(rèn)為在數(shù)據(jù)集里出現(xiàn)的次數(shù)越少,因?yàn)樗u的少
    count = 1000 / series[k]
    for j in range(count):
        dis = random.random() * 10
        #實(shí)際的價(jià)格是品牌的出場(chǎng)價(jià)除以里程數(shù)的開(kāi)方
        price = series[k] / math.sqrt(1.0 + dis)
        print str(price) + '\t' + str(dis) + '\t' + str(k)

這里,車的品牌是一個(gè)離散特征,而里程是個(gè)連續(xù)的特征。問(wèn)題的目標(biāo)是,給定品牌和里程,同時(shí)預(yù)測(cè)車的價(jià)格(回歸問(wèn)題),以及車的價(jià)格區(qū)間(分類問(wèn)題)。我們用如下的網(wǎng)絡(luò)來(lái)解決這個(gè)問(wèn)題:

# dis 是輸入的里程
dis = mx.symbol.Variable('dis')
# price 是要預(yù)測(cè)的目標(biāo)價(jià)格
price = mx.symbol.Variable('price')
# price_interval 是要預(yù)測(cè)的價(jià)格區(qū)間
price_interval = mx.symbol.Variable('price_interval')
# series 是輸入的車的品牌
series = mx.symbol.Variable('series')

dis = mx.symbol.Flatten(data = dis, name = "dis_flatten")
series = mx.symbol.Embedding(data = series, input_dim = 200,
                             output_dim = 100, name = "series_embed")
series = mx.symbol.Flatten(series, name = "series_flatten")

net = mx.symbol.Concat(*[dis, series], dim = 1, name = "concat")
net = mx.symbol.FullyConnected(data = net, num_hidden = 100, name = "fc1")
net = mx.symbol.Activation(data = net, act_type="relu")
net = mx.symbol.FullyConnected(data = net, num_hidden = 100, name = "fc2")
net = mx.symbol.Activation(data = net, act_type="relu")
net = mx.symbol.FullyConnected(data = net, num_hidden = 1, name = "fc3")
# 這里最后為什么用relu呢?是因?yàn)閮r(jià)格一定是個(gè)正數(shù)
net = mx.symbol.Activation(data = net, act_type="relu")
net = mx.symbol.LinearRegressionOutput(data = net, label = price, name = "lro")

net2 = mx.symbol.Concat(*[dis, series], dim = 1, name = "concat")
net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 100, name = "fc21")
net2 = mx.symbol.Activation(data = net2, act_type="relu")
net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 100, name = "fc22")
net2 = mx.symbol.Activation(data = net2, act_type="relu")
net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 8, name = "fc23")
net2 = mx.symbol.Activation(data = net2, act_type="relu")
net2 = mx.symbol.SoftmaxOutput(data = net2, label = price_interval, name="sf")

# 這里net預(yù)測(cè)price,net2預(yù)測(cè)price_interval, 最后group在一起返回
return mx.symbol.Group([net, net2])

這個(gè)例子里,我們需要同時(shí)提供dis, series, price, price_interval 四個(gè)變量。常見(jiàn)的Iter似乎不支持這個(gè)功能,因此可以自己實(shí)現(xiàn)一個(gè):

class PriceIter(mx.io.DataIter):
    def __init__(self, fname, batch_size):
        super(PriceIter, self).__init__()
        self.batch_size = batch_size
        self.dis = []
        self.series = []
        self.price = []
        # 這里預(yù)先從文件讀入所有的數(shù)據(jù)存下來(lái)
        for line in file(fname):
            price, d, s = line.strip().split("\t")
            self.price.append(float(price))
            self.series.append(np.array([int(s)], dtype = np.int))
            self.dis.append(np.array([float(d) / 10.0]))

        # 輸入數(shù)據(jù)的shape
        self.provide_data = [('dis', (batch_size, 1)),
                             ('series', (batch_size, 1))]
        # 輸出數(shù)據(jù)的shape
        self.provide_label = [('price', (batch_size, )),
                              ('price_interval', (batch_size,))]

    def __iter__(self):
        count = len(self.price)
        for i in range(count / self.batch_size):
            bdis = []
            bseries = []
            blabel = []
            blabel_interval = []
            for j in range(self.batch_size):
                k = i * self.batch_size + j
                bdis.append(self.dis[k])
                bseries.append(self.series[k])
                blabel.append(self.price[k])
                blabel_interval.append(interval(self.price[k]))

            data_all = [mx.nd.array(bdis),
                        mx.nd.array(bseries)]
            label_all = [mx.nd.array(blabel), mx.nd.array(blabel_interval)]
            data_names = ['dis', 'series']
            label_names = ['price', 'price_interval']

            data_batch = Batch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        pass

全部的例子見(jiàn)這里

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 一.互聯(lián)網(wǎng)廣告特征工程 博文《互聯(lián)網(wǎng)廣告綜述之點(diǎn)擊率系統(tǒng)》論述了互聯(lián)網(wǎng)廣告的點(diǎn)擊率系統(tǒng),可以看到,其中的logis...
    jlinleung閱讀 945評(píng)論 0 6
  • 有時(shí)候站在路邊看著人來(lái)人往,會(huì)覺(jué)得城市比沙漠還要荒涼。每個(gè)人都靠的那么近,但完全不知道彼此的心事,那么嘈雜,那么多...
    熙兮晚歸閱讀 227評(píng)論 0 0
  • QUESTION: 開(kāi)始文章前,我想問(wèn)問(wèn)朋友,什么才是對(duì)你最重要的東西? 回到三天之前我問(wèn)了自己同樣的問(wèn)題,我回答...
    小流于江海閱讀 637評(píng)論 0 0
  • 1每天零極限,對(duì)不起,請(qǐng)?jiān)彛x謝你,我愛(ài)你 2冥想,晨起10分鐘放空自我 3極簡(jiǎn)主義,少就是幸福 4讀書(shū)讀書(shū)讀書(shū)...
    蝸牛向上爬啊爬閱讀 253評(píng)論 0 0
  • 經(jīng)常 在某個(gè)忽然的夜里,心里就有好多的感動(dòng),特別想念遠(yuǎn)方的媽媽......文字真是奇妙,有些話當(dāng)著面是說(shuō)不出口的,...
    水之晶閱讀 521評(píng)論 2 2

友情鏈接更多精彩內(nèi)容