在使用tf-gpu1.15 object detection api 中發(fā)現(xiàn)無論是訓練模型還是通過模型測試單張圖片的時候都會出現(xiàn)這個問題。
原因:TensorFlow默認的是占用全部gpu,這樣就會導致后面的卷積運算中沒有足夠的內存進行計算
解決:通過以下代碼更改解決
train.py
if __name__ == '__main__':
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
tf.app.run()
object_detection_tutorial.py
with detection_graph.as_default():
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# with tf.Session(config=config) as session:
with tf.Session(graph=detection_graph,config=config) as sess: