Machine Learning Crash Course: Resumable Training with Checkpoints
When downloading files we often rely on resumable downloads, so that a paused download can pick up where it left off instead of starting over. In machine learning, the most time-consuming step is training: some models involve so much data that even a high-end GPU server needs several days to produce a result. If a power failure or crash interrupts training partway through and there is no way to resume, the way an interrupted download resumes, a programmer might well feel like smashing the computer. In this article I will show you how to save a model's state during training and restore it later to continue training.

import os
import tensorflow as tf
from tensorflow import keras

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

# Define a simple sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation="relu", input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])
    model.compile(optimizer="adam",
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
First, reuse the handwritten-digit-recognition code above to build and compile the model. Before training starts, set up a model-checkpoint callback; once training begins, TensorFlow invokes this callback and keeps writing the current training state to the checkpoint.

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training
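With the settings above, the callback overwrites the same checkpoint file on every save. `ModelCheckpoint` also accepts a filepath containing `str.format`-style placeholders such as `{epoch:04d}`, which Keras fills in at save time, giving you one file per save and a history to roll back to. A quick plain-Python illustration of how that template expands (no training needed; `training_2` is just an example directory name):

```python
# ModelCheckpoint fills str.format-style placeholders in its filepath,
# producing a distinct file per save instead of overwriting one file.
checkpoint_template = "training_2/cp-{epoch:04d}.ckpt"

paths = [checkpoint_template.format(epoch=e) for e in (1, 5, 10)]
print(paths)
# ['training_2/cp-0001.ckpt', 'training_2/cp-0005.ckpt', 'training_2/cp-0010.ckpt']
```

Zero-padding the epoch number (`04d`) keeps the files sorted in training order when you list the directory.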
Looking inside the checkpoint directory, we can see the saved files:

os.listdir(checkpoint_dir)
["cp.ckpt.index", "cp.ckpt.data-00000-of-00001", "checkpoint"]
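The third file, `checkpoint`, is a small plain-text index recording which save is the most recent; `tf.train.latest_checkpoint(checkpoint_dir)` reads it for you and returns the path to restore from. As a rough sketch of what that lookup does (the single-line parse below is a simplification of the real index format, and the function is a stand-in, not the actual TensorFlow implementation):

```python
import os
import re

def latest_checkpoint(checkpoint_dir, index_text):
    """Simplified stand-in for tf.train.latest_checkpoint: parse the
    plain-text 'checkpoint' index file contents and return the path
    of the newest save, or None if the index is empty."""
    m = re.search(r'model_checkpoint_path:\s*"([^"]+)"', index_text)
    return os.path.join(checkpoint_dir, m.group(1)) if m else None

index_text = 'model_checkpoint_path: "cp.ckpt"'
print(latest_checkpoint("training_1", index_text))  # training_1/cp.ckpt
```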
Now let's build a fresh model without restoring the trained state, and evaluate it without any training. The accuracy is only 11.50% — very low.

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

32/32 - 0s - loss: 2.3609 - sparse_categorical_accuracy: 0.1150
Untrained model, accuracy: 11.50%
Next, let's restore the previously saved training state and evaluate again, still without retraining. The accuracy is 86.40% — much higher. This shows we reused the earlier training result directly: training has been "resumed" just like a download.

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

32/32 - 0s - loss: 0.4329 - sparse_categorical_accuracy: 0.8640
Restored model, accuracy: 86.40%
The code above saves automatically via the checkpoint callback. We can also save the model's state manually whenever we like, and load a saved state back later. The code looks like this:

# Save the weights
model.save_weights("./checkpoints/my_checkpoint")

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights("./checkpoints/my_checkpoint")

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
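To make this a true "resume", a training script usually checks at startup whether a checkpoint already exists and only calls `load_weights` when it does; otherwise it trains from scratch. A minimal sketch of that decision (the helper name is hypothetical, and the `cp.ckpt` layout matches the paths used earlier; it relies on the `checkpoint` index file that the TensorFlow checkpoint writer creates alongside the weights):

```python
import os
import tempfile

def find_resume_point(checkpoint_dir, weights_name="cp.ckpt"):
    """Return the weights path to pass to load_weights(), or None
    to start training from scratch."""
    if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
        return os.path.join(checkpoint_dir, weights_name)
    return None

with tempfile.TemporaryDirectory() as d:
    print(find_resume_point(d))   # None: first run, train from scratch
    open(os.path.join(d, "checkpoint"), "w").close()
    print(find_resume_point(d))   # <dir>/cp.ckpt: resume from here
```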
Sometimes we want to hand a trained model to someone who doesn't have the code we used to build it. In that case saving only the weights and other state is not enough; we need to save the entire model with model.save(). The saved files contain everything from model construction, compilation, and training, so the recipient can load and use the model directly.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save("saved_model/my_model")
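Note that model.save() chooses the format from the target filename: a path ending in ".h5" produces a single HDF5 file, while any other path (like "saved_model/my_model" above) becomes a SavedModel directory. A simplified sketch of that dispatch, not the real Keras code:

```python
def save_format_for(path):
    """Simplified sketch of how Keras picks a save format from the
    target filename (not the actual implementation)."""
    if path.endswith((".h5", ".hdf5")):
        return "HDF5 single file"
    return "SavedModel directory"

print(save_format_for("saved_model/my_model"))  # SavedModel directory
print(save_format_for("my_model.h5"))           # HDF5 single file
```

The SavedModel directory form is what TensorFlow Serving and other deployment tools expect; the single HDF5 file is convenient to email or attach.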
Given a complete saved model, load_model restores it directly, ready for prediction, with no need to rebuild or recompile anything. This capability lets us stand on the shoulders of giants: if everyone open-sourced their fully trained models, data centers around the world could save hundreds of millions of kilowatt-hours of electricity.

new_model = tf.keras.models.load_model("saved_model/my_model")

# Check its architecture
new_model.summary()