代碼補全快餐教程(1) - 30行代碼見證奇跡
下面是我用30多行代碼,包含了很多空行和注釋的代碼寫成的代碼補全模型。我們先看看效果吧。
補全效果案例
先來看個比較普通的(Python, Keras)
已知:
y_train = keras.utils.to_categorical(y_train, num_classes)\ny_test = keras.
補全之后是這樣的:
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
算法能夠知道把括號中的y_train換成y_test
再看一個把我感動哭了的(Typescript, vscode)
輸入如下:
text = "let disposable_begin_buffer = vscode.commands.registerCommand('extension.littleemacs.beginningOfBuffer',\nmove.beginningOfBuffer);\nlet disposable_end_buffer = vscode.commands."
輸出是這樣的:
let disposable_begin_buffer = vscode.commands.registerCommand('extension.littleemacs.beginningOfBuffer',move.beginningOfBuffer);
let disposable_end_buffer = vscode.commands.registerCommand('extension.littleemacs.endendOfBuffer',move.endendOfBuffer);
請注間這其中的難度,變量定義中用的是begin,而extension和move中用的都是beginning,算法能將其換成endend而保持OfBuffer不變。
函數(shù)的補全(Java)
輸入如下:
public class Issue {\nprivate Long id;\nprivate String filename;\nprivate Long lineNum;\nprivate String issueString;\npublic Long getId() {
輸出如下:
public class Issue {
private Long id;
private String filename;
private Long lineNum;
private String issueString;
public Long getId() { return id; }
對于IntellJ IDEA來說這不算什么,但是對于完全不懂Java語言的文本模型,隔著幾行其它變量能把return id給補全出來還是很了不起的
能看懂循環(huán)(Java, Android)
題目取自Android源代碼:final int N = a.getIndexCount();\nfor (int i = 0; i < N; i++) {\nint index = a.
補全結(jié)果如下:
final int N = a.getIndexCount();
for (int i = 0; i < N; i++) {
int attr = a.getAttribute(i);
從上下文看,鬼知道a是個啥類型,但是這個模型可以根據(jù)變量attr猜出來getAttribute。
話說我的attr想用a.getIndex獲取怎么辦?不用重新訓(xùn)練了,給模型寫個例子就好了:
輸入如下:
int attr = a.getIndex(i);\nfor (int i = 0; i < N; i++) {\nint attr = a.
補全的結(jié)果就是這樣了:
int attr = a.getIndex(i);
for (int i = 0; i < N; i++) {
int attr = a.getIndex(i);
還會給變量名做加法!(Python, Keras)
輸入如下:
tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img)\ntower_1 = Conv2D(
見證奇跡的時刻到了:
tower_1 = Conv2D(64, (1, 1), padding='same', activation='relu')(input_img) tower_2 = Conv2D(64
也就是補完tower_1之后,算法不過癮開始命名了一個tower_2繼續(xù)補!
這對卷積一寫好多層太方便了有沒有?
這個加法對于變量名管用,對于下標(biāo)變量也有效。
我們再來看個例子:
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
不管是用y_train還是x_test做補全時,shape后面的[0]就會被加1變成[1],在這里是不適用的,補完了需要手工補一下。
多語言支持 (common lisp)
為了證明不是針對Java和Python做的特殊優(yōu)化,我們來個lisp的例子看下:
text = '(progn (setq a (func1 0)) (setq b'
輸出如下:
(progn (setq a (func1 0)) (setq b (func2 0)) (setq c (func3 0)) (setq d (func4 0))
不但知道func1加1變func2,setq的變量,也從b變成c,d以此類推。
30行代碼創(chuàng)造奇跡
上面這樣神奇的功能,我們借助最新的神經(jīng)網(wǎng)絡(luò)自然處理的工具,只有30多行代碼就可以搞定:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# 加載詞匯表
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 輸入待補全的文本
text = 'int idx = a.getIndex(i);\nfor (int i = 0; i < N; i++) {\nint attr = a.'
predicted_text = text
# 每一個只能補一個token出來,補一句話需要多次,30次是我拍腦袋的
for i in range(0,30):
# 以上次預(yù)測結(jié)果作為本次的輸入,所謂的自回歸
indexed_tokens = tokenizer.encode(predicted_text)
# 將讀出的索引標(biāo)記轉(zhuǎn)化成PyTorch向量
tokens_tensor = torch.tensor([indexed_tokens])
# 加載模型中預(yù)訓(xùn)練好的權(quán)值
model = GPT2LMHeadModel.from_pretrained('gpt2')
# 設(shè)置為eval模式,這樣就不會執(zhí)行訓(xùn)練模式下的Dropout過程
model.eval()
# 使用GPU進行加速,誠實地講速度不太快
tokens_tensor = tokens_tensor.to('cuda')
model.to('cuda')
# 進行推理
with torch.no_grad():
outputs = model(tokens_tensor)
predictions = outputs[0]
# 獲取預(yù)測的下一個子詞
predicted_index = torch.argmax(predictions[0, -1, :]).item()
# 解碼成我們都讀懂的文本
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
# 打印輸入結(jié)果
print(predicted_text)
用來自動寫作
其實,上面所用的gpt-2模型,并不是給代碼補全用的,用來自動寫點的東西到時它的本業(yè)。
比如大家可以試試,給“To be or not to be"補全下,我的結(jié)果如下“To be or not to be, the only thing that matters is that you're a good person.”
再比如“I have a dream that one day”,我的結(jié)果如下“I have a dream that one day I will be able to live in a world where I can be a part of something bigger than myself.”
如果不想寫代碼的話,可以直接在https://transformer.huggingface.co/doc/gpt2-large中去直接試驗。
如下圖所示,寫代碼寫文字都可以:

安裝環(huán)境
如果想試用上面的代碼的話,只需要安裝transformers庫就好了。
pip install transformers
另外,transformers庫依賴PyTorch或Tensorflow之一,我們上面的代碼是基于PyTorch的,還需要安裝一下PyTorch:
pip3 install torch torchvision
在Windows下安裝命令稍有不同,需要指定版本號,例:
pip3 install torch===1.3.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html