Who this is for
- Machine Learning Engineers who need to ship models to production or share them with teammates.
- Data scientists who hand off trained models to engineering or MLOps.
- Anyone maintaining ML systems across environments (local, cloud, containers).
Prerequisites
- Basic Python (functions, files, virtual environments).
- Familiarity with at least one ML framework: scikit-learn, PyTorch, or TensorFlow/Keras.
- Comfort with training a small model and running inference.
Why this matters
Real-world ML is more than training. You must save models reliably, load them across machines, and keep them working after library updates. Typical tasks:
- Ship a scikit-learn pipeline to a batch scoring job.
- Serve a PyTorch model on CPU even if it was trained on GPU.
- Archive a TensorFlow model to reproduce results months later.
- Migrate a model to another runtime (e.g., ONNX) for faster inference.
Concept explained simply
Model persistence means turning a trained model (parameters + structure + sometimes preprocessing) into files you can store, version, and later load for inference or retraining.
Mental model
- Think of a model as a recipe + ingredients: the architecture (recipe), the weights (ingredients), and the preprocessing (kitchen setup). Saving well means preserving all three so your dish tastes the same later.
- Compatibility lives at three layers: the Python version, the library/framework version, and the file format. The more standard the format, the safer it is across time and systems.
Core tools and formats you will use
- scikit-learn: joblib.dump / joblib.load for models and Pipelines. Prefer saving a full Pipeline to avoid missing preprocessing steps.
- PyTorch: torch.save(state_dict) and model.load_state_dict(...). Use map_location to move GPU-trained models to CPU. Save model code separately.
- TensorFlow/Keras: model.save("path") writes the SavedModel format by default in TF 2.x; the legacy H5 format is available via model.save("model.h5").
- Cross-framework: ONNX for standardized inference across runtimes and languages.
- Artifacts to include: model weights, architecture definition (or code), preprocessing objects, label encoders, and metadata (versions, date, metrics).
Security and safety note
- Never load pickle/joblib files from untrusted sources; loading can execute code. Only load artifacts you trust.
- Store hashes and sign artifacts in production contexts when possible (a minimal hashing sketch follows).
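A minimal hashing sketch, assuming an artifact file such as the iris_pipe.joblib produced in Example 1 below; the sidecar file name is illustrative, not a fixed convention.

import hashlib
import json

def sha256_of(path, chunk_size=65536):
    # Stream the file so large artifacts never need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Record the hash next to the artifact so loaders can verify integrity first.
artifact = "iris_pipe.joblib"
with open(artifact + ".sha256.json", "w") as f:
    json.dump({"file": artifact, "sha256": sha256_of(artifact)}, f)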
Worked examples
Example 1 — scikit-learn Pipeline with joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
import joblib
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000, random_state=42)),
])
pipe.fit(X_train, y_train)
pred_before = pipe.predict(X_test)
joblib.dump(pipe, "iris_pipe.joblib")
loaded = joblib.load("iris_pipe.joblib")
pred_after = loaded.predict(X_test)
print("Identical predictions:", np.array_equal(pred_before, pred_after))
Example 2 — PyTorch: save and load state_dict
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3)
        )

    def forward(self, x):
        return self.net(x)
model = SmallNet().to(device)
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.01)  # stand-in for real training
# Save weights only
torch.save(model.state_dict(), "smallnet.pt")
# Load on CPU for inference
loaded = SmallNet()
# map_location moves GPU-trained weights onto the CPU; on PyTorch >= 1.13 you
# can also pass weights_only=True for safer deserialization of state_dicts.
loaded.load_state_dict(torch.load("smallnet.pt", map_location=torch.device("cpu")))
loaded.eval()  # disables dropout/batch-norm training behavior
x = torch.randn(5, 4)
with torch.no_grad():
    out = loaded(x)
print(out.shape) # should be [5, 3]
Example 3 — Keras: SavedModel and H5
import numpy as np
import tensorflow as tf
from tensorflow import keras
model = keras.Sequential([
    keras.layers.Input(shape=(4,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3),
])
model.compile(optimizer="adam", loss="mse")
# Quick fake training pass on random data
X = np.random.randn(50, 4).astype("float32")
y = np.random.randn(50, 3).astype("float32")
model.fit(X, y, epochs=1, verbose=0)
# SavedModel (recommended)
model.save("keras_savedmodel")
# H5 format
model.save("keras_model.h5")
loaded_sm = keras.models.load_model("keras_savedmodel")
loaded_h5 = keras.models.load_model("keras_model.h5")
pred1 = loaded_sm(X[:5]).numpy()
pred2 = loaded_h5(X[:5]).numpy()
print("Close predictions:", np.allclose(pred1, pred2))
Versioning, compatibility, and security
- Record metadata: framework (e.g., scikit-learn 1.4), Python version, training date, dataset/version, metrics, and random seeds (a sidecar-JSON sketch follows this list).
- Prefer stable formats: SavedModel for TF, state_dict + code for PyTorch, joblib for scikit-learn Pipelines, ONNX for serving across stacks.
- Cross-device: load GPU-trained PyTorch weights on CPU with map_location; avoid device-specific tensors in saved files.
- Safety: do not unpickle unknown files. Treat artifacts like executable code.
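A minimal sidecar-JSON sketch, assuming the pipeline artifact from Example 1; the field names and file name are illustrative, not a standard schema.

import json
import platform
from datetime import datetime, timezone

import sklearn

metadata = {
    "artifact": "iris_pipe.joblib",
    "framework": f"scikit-learn {sklearn.__version__}",
    "python": platform.python_version(),
    "created_utc": datetime.now(timezone.utc).isoformat(),
    "dataset": "iris (sklearn built-in)",
    "metrics": {"accuracy": None},  # fill in your evaluation result
    "random_seed": 42,
}

# Write the sidecar next to the model artifact so they travel together.
with open("iris_pipe.metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)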
Quick checklist before you ship
- Preprocessing is included in the artifact (Pipeline or code).
- Inference is deterministic: the model is set to eval() where needed (PyTorch).
- Version metadata is recorded in a sidecar JSON or in the artifact's path naming.
- A load test passed in a fresh environment or container.
Exercises lab
These exercises mirror the worked examples above. You can do them locally, then finish with the quick test at the end.
- Exercise 1: Persist and reload a scikit-learn Pipeline with joblib. Confirm predictions are identical before and after saving.
- Exercise 2: Load a PyTorch model state_dict on CPU and run inference in eval() mode. Confirm the output shape matches expectations.
- Exercise 3: Save a Keras model as both SavedModel and H5, reload both, and verify outputs are numerically close.
Exercise checklist
- Saved files exist on disk with reasonable size.
- Reload works in a new Python session or separate cell.
- Inference parity: same or very close predictions.
- Metadata recorded somewhere (e.g., a small JSON next to the model).
Common mistakes and self-check
- Saving only the estimator but not preprocessing. Self-check: can you call predict() on raw input and get expected results after reload?
- Forgetting model.eval() in PyTorch. Self-check: are results inconsistent across runs without weight changes?
- Assuming pickle is safe. Self-check: only load artifacts created by your team/build pipeline.
- GPU-only checkpoints. Self-check: can you load with map_location="cpu" successfully?
- No version metadata. Self-check: can someone reproduce this load in six months without asking you?
Practical projects
- Model bundle CLI: a small script that trains, saves an artifact, writes a metadata JSON (framework, versions, metrics), and verifies a round-trip load.
- Framework bridge: export a model to ONNX and run a parity check against the original framework on a test set (a PyTorch starting point is sketched after this list).
- Environment smoke test: minimal Docker or fresh venv that loads your artifact and runs 10 inferences.
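A hedged starting point for the framework bridge, assuming PyTorch plus the onnx and onnxruntime packages are installed; the model is the toy SmallNet from Example 2, and all file names are illustrative.

import numpy as np
import torch
import torch.nn as nn
import onnxruntime as ort  # pip install onnx onnxruntime

class SmallNet(nn.Module):  # same toy model as Example 2
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))

    def forward(self, x):
        return self.net(x)

model = SmallNet().eval()
example = torch.randn(1, 4)
torch.onnx.export(
    model, example, "smallnet.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}},  # allow variable batch size
)

# Run the exported model in ONNX Runtime and compare against PyTorch.
sess = ort.InferenceSession("smallnet.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(5, 4).astype("float32")
onnx_out = sess.run(None, {"input": x})[0]
with torch.no_grad():
    torch_out = model(torch.from_numpy(x)).numpy()
print("Parity:", np.allclose(onnx_out, torch_out, atol=1e-5))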
Mini challenge
You trained a model on GPU in PyTorch but must serve on CPU-only VMs tomorrow. Create a loading snippet that guarantees CPU inference, sets eval(), and verifies the first prediction equals a stored reference vector within a small tolerance.
Hint
- Use map_location when loading and torch.allclose for the parity check (one possible sketch follows).
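One possible sketch, assuming the SmallNet class and "smallnet.pt" checkpoint from Example 2 and previously stored reference tensors; the reference file names are illustrative.

import torch

model = SmallNet()
# map_location guarantees the weights land on CPU even if saved from GPU.
model.load_state_dict(torch.load("smallnet.pt", map_location="cpu"))
model.eval()  # deterministic inference

x = torch.load("reference_input.pt", map_location="cpu")
expected = torch.load("reference_output.pt", map_location="cpu")
with torch.no_grad():
    pred = model(x)

assert torch.allclose(pred, expected, atol=1e-6), "CPU output drifted from reference"
print("CPU parity check passed")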
Learning path
- Start: Save/load within one framework (joblib, torch.save, model.save).
- Next: Add preprocessing into the artifact (Pipelines, tf.Transform, custom preprocess code).
- Then: Versioning and metadata, environment parity checks.
- Advanced: Cross-framework export (ONNX), quantization-aware saving, secure artifact handling.
Next steps
- Complete the exercises and the quick test below.
- Integrate persistence into your training scripts so every run emits a reproducible artifact.
- Plan a migration path for your current models (e.g., add metadata, test CPU loads).