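Some readers are not very familiar with writing a custom Iter in mxnet, or with multi-output networks, so I will use a fairly involved example to illustrate both:
1. The features include both continuous and discrete features
2. A regression problem and a classification problem must be solved at the same time
In the spirit of End-to-End learning, we do no feature engineering, which also means no discretization. Continuous features can therefore be fed in directly, while discrete features are fed in through an Embedding. To solve regression and classification at the same time, we need two Loss layers.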
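To make the input scheme concrete, here is a minimal pure-Python sketch of the idea (it is not mxnet code). An embedding is treated as a lookup table that maps each discrete id to a dense vector, which is then concatenated with the continuous value. The names `make_embedding_table` and `embed_and_concat`, and the table sizes, are assumptions for illustration; in a real network the table entries are learned parameters.

```python
import random

def make_embedding_table(num_ids, dim, seed=0):
    """Build a table with one `dim`-sized vector per discrete id (random init)."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(num_ids)]

def embed_and_concat(continuous_value, discrete_id, table):
    """Concatenate the continuous feature with the embedding row of the id."""
    return [continuous_value] + table[discrete_id]

# 200 ids and a 4-dimensional embedding are illustrative sizes only
table = make_embedding_table(num_ids=200, dim=4)
features = embed_and_concat(0.35, 17, table)
print(len(features))  # one continuous value plus the embedding values
```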
Let's make up a simple used-car price estimation problem. We assume a car's price depends on only two factors: its brand and its mileage. Different brands have different factory prices, and the more a car has been driven, the lower its price. Based on this assumption, we can build a dataset with the following code:
import math
import random

# We make up 200 different brands and give each one a factory price
series = [1 + i for i in range(100)] + [101 - i for i in range(100)]
for i in range(10000):
    k = random.randint(0, 199)
    # The more expensive a brand, the less often it appears in the dataset,
    # because fewer of its cars are sold
    count = 1000 // series[k]
    for j in range(count):
        dis = random.random() * 10
        # The actual price is the factory price divided by sqrt(1 + mileage)
        price = series[k] / math.sqrt(1.0 + dis)
        print(str(price) + '\t' + str(dis) + '\t' + str(k))
Here the car's brand is a discrete feature and the mileage is a continuous one. The goal is, given the brand and the mileage, to predict both the car's price (a regression problem) and its price range (a classification problem). We use the following network to solve it:
import mxnet as mx

# The original def line is missing from this excerpt; the function name is assumed
def get_network():
    # dis is the input mileage
    dis = mx.symbol.Variable('dis')
    # price is the target price to predict
    price = mx.symbol.Variable('price')
    # price_interval is the target price range to predict
    price_interval = mx.symbol.Variable('price_interval')
    # series is the input car brand
    series = mx.symbol.Variable('series')

    dis = mx.symbol.Flatten(data = dis, name = "dis_flatten")
    series = mx.symbol.Embedding(data = series, input_dim = 200,
                                 output_dim = 100, name = "series_embed")
    series = mx.symbol.Flatten(series, name = "series_flatten")

    net = mx.symbol.Concat(*[dis, series], dim = 1, name = "concat")
    net = mx.symbol.FullyConnected(data = net, num_hidden = 100, name = "fc1")
    net = mx.symbol.Activation(data = net, act_type="relu")
    net = mx.symbol.FullyConnected(data = net, num_hidden = 100, name = "fc2")
    net = mx.symbol.Activation(data = net, act_type="relu")
    net = mx.symbol.FullyConnected(data = net, num_hidden = 1, name = "fc3")
    # Why a relu at the end? Because the price must be positive
    net = mx.symbol.Activation(data = net, act_type="relu")
    net = mx.symbol.LinearRegressionOutput(data = net, label = price, name = "lro")

    # Renamed from "concat" so that the two Concat nodes do not share a name
    net2 = mx.symbol.Concat(*[dis, series], dim = 1, name = "concat2")
    net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 100, name = "fc21")
    net2 = mx.symbol.Activation(data = net2, act_type="relu")
    net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 100, name = "fc22")
    net2 = mx.symbol.Activation(data = net2, act_type="relu")
    net2 = mx.symbol.FullyConnected(data = net2, num_hidden = 8, name = "fc23")
    net2 = mx.symbol.Activation(data = net2, act_type="relu")
    net2 = mx.symbol.SoftmaxOutput(data = net2, label = price_interval, name="sf")

    # net predicts price and net2 predicts price_interval; group them and return
    return mx.symbol.Group([net, net2])
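Conceptually, the two output layers correspond to two different objectives computed from the same inputs. The following pure-Python sketch (an illustration, not mxnet's internals) computes a squared-error loss for the price head and a softmax cross-entropy loss for the price-range head. The function names and the sample numbers are assumptions.

```python
import math

def squared_error(pred, target):
    """Squared-error loss for the regression head (single sample)."""
    return 0.5 * (pred - target) ** 2

def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(scores, label):
    """Cross-entropy loss for the classification head (single sample)."""
    return -math.log(softmax(scores)[label])

# Made-up values: a predicted price of 42 against a true price of 40,
# and uniform scores over the 8 price-range classes
reg_loss = squared_error(pred=42.0, target=40.0)
cls_loss = cross_entropy(scores=[0.0] * 8, label=3)
print(reg_loss, cls_loss)
```

The two losses live on very different scales, which is worth keeping in mind when both heads are trained together.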
In this example we need to supply four variables at once: dis, series, price, and price_interval. The commonly used Iters do not seem to support this, so we can implement our own:
import mxnet as mx
import numpy as np

# Batch and interval are not defined in this excerpt; see the full example
class PriceIter(mx.io.DataIter):
    def __init__(self, fname, batch_size):
        super(PriceIter, self).__init__()
        self.batch_size = batch_size
        self.dis = []
        self.series = []
        self.price = []
        # Read all the data from the file up front and keep it in memory
        with open(fname) as f:
            for line in f:
                price, d, s = line.strip().split("\t")
                self.price.append(float(price))
                self.series.append(np.array([int(s)], dtype = int))
                # Scale the mileage (generated in [0, 10)) down to [0, 1)
                self.dis.append(np.array([float(d) / 10.0]))
        # Shapes of the input data
        self.provide_data = [('dis', (batch_size, 1)),
                             ('series', (batch_size, 1))]
        # Shapes of the labels
        self.provide_label = [('price', (batch_size, )),
                              ('price_interval', (batch_size,))]

    def __iter__(self):
        count = len(self.price)
        # Samples that do not fill a whole batch are dropped
        for i in range(count // self.batch_size):
            bdis = []
            bseries = []
            blabel = []
            blabel_interval = []
            for j in range(self.batch_size):
                k = i * self.batch_size + j
                bdis.append(self.dis[k])
                bseries.append(self.series[k])
                blabel.append(self.price[k])
                blabel_interval.append(interval(self.price[k]))
            data_all = [mx.nd.array(bdis),
                        mx.nd.array(bseries)]
            label_all = [mx.nd.array(blabel), mx.nd.array(blabel_interval)]
            data_names = ['dis', 'series']
            label_names = ['price', 'price_interval']
            data_batch = Batch(data_names, data_all, label_names, label_all)
            yield data_batch

    def reset(self):
        # Nothing to do: __iter__ starts from the first sample on every pass
        pass
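The iterator relies on two helpers that are not shown in this excerpt, `Batch` and `interval`. One possible sketch follows; both are assumptions rather than the original definitions. `Batch` is a simple container in the style often seen in mxnet examples, and the thresholds in `interval` are made-up values chosen only so that prices map to 8 classes, matching the 8 outputs of fc23.

```python
class Batch(object):
    """Simple container holding one mini-batch (hypothetical helper)."""
    def __init__(self, data_names, data, label_names, label):
        self.data_names = data_names
        self.data = data
        self.label_names = label_names
        self.label = label

    @property
    def provide_data(self):
        # (name, shape) pairs; assumes each array exposes a .shape attribute
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

# Hypothetical upper bounds for the first 7 price ranges; the 8th is open-ended
THRESHOLDS = [5, 10, 20, 30, 40, 60, 80]

def interval(price):
    """Map a price to a class index in 0..7 (thresholds are assumptions)."""
    for idx, upper in enumerate(THRESHOLDS):
        if price < upper:
            return idx
    return len(THRESHOLDS)

print(interval(3.2), interval(25.0), interval(101.0))
```

Since the generated prices are skewed toward low values, evenly populated classes would need thresholds picked from the actual price distribution rather than fixed by hand.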
The full example is available here