# Checkpointing

## Overview
Checkpointing is the process of saving the current state of a running job at regular intervals so that it can be resumed later from that state, rather than starting from scratch. This is especially useful in long-running or resource-intensive tasks on HPC systems like Wulver, where interruptions or failures may occur.
Checkpointing typically involves:
- Periodic saving of application state (memory, variables, file handles, etc.)
- Resuming computation from the last saved state
- Integration with SLURM job re-submission or recovery workflows (a job-script sketch follows below)
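As a minimal sketch of how checkpointing fits into a SLURM re-submission workflow, the job script below asks SLURM for a warning signal shortly before the time limit and then requeues itself; the job name, time limit, and application command (`train.py`) are placeholders, and the application is assumed to load its own checkpoint on startup.

```bash
#!/bin/bash
#SBATCH --job-name=ckpt-demo
#SBATCH --time=04:00:00
#SBATCH --requeue                 # allow SLURM to requeue this job
#SBATCH --signal=B:USR1@300       # signal the batch shell 5 minutes before the time limit

# When the warning signal arrives, requeue the job; the next run resumes
# from whatever checkpoint the application has written so far.
trap 'scontrol requeue "$SLURM_JOB_ID"' USR1

# Placeholder application: loads its last checkpoint (if any) and continues.
srun python train.py &
wait
```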
## Why Use Checkpointing?
| Benefit | Description |
|---|---|
| Failure Recovery | Resume jobs from the last checkpoint after a node crash or time-limit expiration. |
| Efficient Resource Use | Prevents wasted computation time when long jobs are interrupted. |
| Preemption Tolerance | Helps tolerate job preemption on shared clusters or spot instances. |
| Job Time Limit Bypass | Breaks large jobs into smaller chunks that fit within SLURM time limits. |
## Examples for checkpointing
### Python (pickle)

Save intermediate state using Python's built-in `pickle` module; this approach is ideal for lightweight scripts.
```python
import pickle
import time

def save_checkpoint(data, filename="checkpoint.pkl"):
    with open(filename, "wb") as f:
        pickle.dump(data, f)

def load_checkpoint(filename="checkpoint.pkl"):
    try:
        with open(filename, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        # No checkpoint yet: start from the beginning
        return {"iteration": 0}

state = load_checkpoint()
for i in range(state["iteration"], 1000):
    # Do some work
    time.sleep(1)
    print(f"Running step {i}")
    # Save progress every 100 steps, recording the next step to run on resume
    if i % 100 == 0:
        save_checkpoint({"iteration": i + 1})
```
### PyTorch

A common practice in PyTorch is to checkpoint the model weights, optimizer state, and epoch index so that training can resume after an interruption.
```python
import torch

# Save checkpoint
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, 'checkpoint.pth')

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
```
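To make the recovery explicit, the sketch below resumes the training loop from the epoch after the last one saved; `model`, `optimizer`, `num_epochs`, and `train_one_epoch()` are assumed to be defined elsewhere and are placeholders here.

```python
# Resume from the epoch after the last completed one.
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)   # placeholder for the actual training step
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'checkpoint.pth')
```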
### TensorFlow/Keras

Using Keras callbacks, checkpoints are saved automatically during training. Only the model weights are saved to keep storage requirements low.
```python
import tensorflow as tf

model = tf.keras.models.Sequential([...])

checkpoint_path = "checkpoints/model.ckpt"
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   save_weights_only=True,
                                                   save_freq='epoch')

model.fit(data, labels, epochs=10, callbacks=[checkpoint_cb])
```
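To resume, rebuild the same model architecture and load the saved weights back with `load_weights`; a minimal sketch, assuming the same `Sequential([...])` definition as above:

```python
# Rebuild the architecture, then restore the saved weights and keep training.
model = tf.keras.models.Sequential([...])
model.load_weights(checkpoint_path)
model.fit(data, labels, epochs=10, callbacks=[checkpoint_cb])
```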
### C/C++

In C/C++, you can implement basic checkpointing by writing a loop index or other state to a file and reading it back at the start of the next run. This is ideal for simple simulations or compute-intensive loops.
```c
int main() {
    // load_checkpoint() and save_checkpoint() are user-provided helpers;
    // one possible implementation is sketched below.
    int current_step = load_checkpoint("state.dat");
    for (int i = current_step; i < MAX; i++) {
        // Do the work for step i
        if (i % 1000 == 0) {
            save_checkpoint(i, "state.dat");
        }
    }
    return 0;
}
```
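The helper functions only need to read and write a single integer; one possible sketch using standard C I/O is shown below (the loop bound `MAX`, the file format, and the function names are illustrative).

```c
#include <stdio.h>

#define MAX 1000000   /* illustrative loop bound */

/* Return the last saved step, or 0 if no checkpoint file exists yet. */
int load_checkpoint(const char *path) {
    int step = 0;
    FILE *f = fopen(path, "r");
    if (f != NULL) {
        if (fscanf(f, "%d", &step) != 1) {
            step = 0;
        }
        fclose(f);
    }
    return step;
}

/* Overwrite the checkpoint file with the current step. */
void save_checkpoint(int step, const char *path) {
    FILE *f = fopen(path, "w");
    if (f != NULL) {
        fprintf(f, "%d\n", step);
        fclose(f);
    }
}
```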
### GROMACS

GROMACS supports checkpointing with `.cpt` files during molecular dynamics simulations:
```bash
# Save a checkpoint every 15 minutes
gmx mdrun -deffnm simulation -cpt 15

# Restart from the checkpoint
gmx mdrun -s topol.tpr -cpi simulation.cpt
```
### LAMMPS

LAMMPS checkpointing is usually done with the `write_restart` and `read_restart` commands:
```
# Write a restart file with the current state of the simulation
write_restart restart.equilibration

# In a new job, continue from the restart file
read_restart restart.equilibration
```
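LAMMPS can also write restart files periodically during a run with the `restart` command; in the sketch below the `*` in the filename is replaced by the current timestep (the interval and filename are illustrative).

```
# Write a restart file every 1000 timesteps
restart 1000 restart.*.equil
```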
### OpenFOAM

In OpenFOAM, checkpointing is done by writing time-step directories to disk at a regular interval and restarting from the latest one. Both settings are made in `system/controlDict`:

```
// In system/controlDict
writeControl    timeStep;     // write results every writeInterval time steps
writeInterval   100;

// On restart, continue from the most recent time directory
startFrom       latestTime;
```