【學(xué)習(xí)tensorflow2】有用的API匯總

  1. 動(dòng)態(tài)顯存分配
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
  1. 隨機(jī)數(shù)種子: 為
tf.random.set_seed(2317)
  1. 混合精度
opt = Adam()
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(optimizer=opt, loss="...")
  1. 中斷訓(xùn)練與繼續(xù)訓(xùn)練
reloaded = False
# 參數(shù)為鍵值對(duì), 如global_epoch=global_epoch, 等式左邊是key(自行定義), 右邊是value(tf的變量, 模型, 優(yōu)化器等).
checkpoint = tf.train.Checkpoint(global_epoch=global_epoch, model=model)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if reloaded:
    checkpoint.restore(manager.latest_checkpoint)
while True:
    # Train.
    manager.save()
  1. 日志可視化
log_writer = tf.summary.create_file_writer(log_dir)
def write_log(l, name):
    with log_writer.as_default():
        tf.summary.scalar(name, l, step=global_epoch)
    log_writer.flush()
# 使用tensorboard --logdir [log_dir]可視化日志.
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容