LinerUnit

線性單元

  1. 感知器有一個問題,當(dāng)面對的數(shù)據(jù)集不是線性可分的時候,『感知器規(guī)則』可能無法收斂,這意味著我們永遠(yuǎn)也無法完成一個感知器的訓(xùn)練。為了解決這個問題,我們使用一個連續(xù)的線性函數(shù)來替代感知器的階躍函數(shù),這種感知器就叫做線性單元。線性單元在面對線性不可分的數(shù)據(jù)集時,會收斂到一個最佳的近似上。
  2. 那么線性單元就是將感知機(jī)的輸出激活函數(shù)由分段函數(shù)改為了連續(xù)函數(shù),進(jìn)而輸出的值域也由{0,1}\rightarrow[-\infty,+\infty]

舉例說明

當(dāng)我們說模型時,我們實際上在談?wù)摳鶕?jù)輸入x預(yù)測輸出y的算法。比如,x可以是一個人的工作年限,y可以是他的月薪,我們可以用某種算法來根據(jù)一個人的工作年限來預(yù)測他的收入。\\y=w*x+b

其中w,b是可以擬合年限輸入和月薪輸出的待求權(quán)重參數(shù)。工作年限稱為一個特征,輸入可以包含多個特征如:行業(yè),公司,職級等。當(dāng)特征變多時,對應(yīng)的每個特征都需要一個權(quán)重w_i用于擬合輸入和輸出之間的關(guān)系。
\\y = w_1*x_1+w_2*x_2+\dots+w_n*x_n+b,矩陣表示
y=\textbf{W}^T\textbf{X}\\其中
\textbf{W}=\begin{bmatrix} w_i\\ \vdots \\ w_n\\ b \\ \end{bmatrix}, \textbf{X}=\begin{bmatrix} x_i \\ \vdots \\ x_n \\ 1\\ \end{bmatrix}\\

代碼

由于相較于Perceptron只改變了激活函數(shù),所以我們可以繼承Perceptron快速實現(xiàn)LinerUnit

class LinerUnit(Perceptron):
    def __init__(self, input_dim, activator) -> None:
        super().__init__(input_dim, activator)

生成訓(xùn)練數(shù)據(jù),定義可視化

# 新定義的連續(xù)線性激活函數(shù)
def liner_activater(x):
    return x

def get_training_dataset():
    """
    construct training_set, consist of n samples
    Working years and corresponding salary.
    """
    data = [[5], [3], [8], [1.4], [10.1], [8.1]]
    labels = [5500, 2300, 7600, 1800, 11400, 20000]
    return data, labels

def train_liner_unit(iterations, lr):
    """
    Train a liner_unit with training_set.
    """
    lu = LinerUnit(input_dim=1, activator=liner_activater)
    lu.train(*get_training_dataset(), iterations=iterations, lr=lr)
    return lu

def show_results(linear_unit, samples):
    """
    Visualize the line after the linear unit fit
    """
    predicts = [linear_unit.predict(s) for s in samples]
    plt.scatter(samples, predicts, marker="o")
    x_fit = np.linspace(start=0, stop=max(samples), num=100)
    y_fit = linear_unit.weights * x_fit + linear_unit.bias
    plt.plot(x_fit, y_fit, linestyle="-")
    plt.xlabel("Working years")
    plt.ylabel("Salary")
    plt.show()

訓(xùn)練,測試,并可視化

if __name__ == "__main__":
    linear_unit = train_liner_unit(10, 0.1)
    test_samples = [[3.4], [15], [1.5], [6.3], [8]]
    # test
    for year in test_samples:
        print(f"Work {year} years, monthly salary = {linear_unit.predict(year)}")

    show_results(linear_unit=linear_unit, samples=test_samples)

結(jié)果

控制臺輸出.png
可視化結(jié)果.png
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容