
How to generate training artifacts for the on-device learning feature

Applicable for STM32MP23x lines, STM32MP25x lines


1. Article purpose

This article demonstrates how to generate the training artifacts on your host computer as a part of the offline phase of an on-device learning workflow.

2. PC Prerequisites

To generate the artifacts, you must install some Python modules on your host computer.


The onnxruntime-training module is required to generate the training artifacts:

pip3 install onnxruntime-training==1.19.2

Torchvision [1] is required to load pre-trained models and datasets:

pip3 install torchvision
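Because the steps below are pinned to onnxruntime-training 1.19.2, it can help to confirm which versions are actually installed. The following standard-library sketch prints the installed versions of the packages used in this article (torch and onnx are also imported by the scripts below). The helper name installed_version is our own, and the package list is an assumption to adapt to your environment.

```python
"""Print the installed versions of the Python packages used in this article."""
from importlib import metadata

# pip distribution names; torch and onnx are imported by the scripts below.
PACKAGES = ["onnxruntime-training", "torchvision", "torch", "onnx"]
PINNED_ORT_TRAINING = "1.19.2"


def installed_version(dist_name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in PACKAGES:
        version = installed_version(name)
        print(f"{name}: {version if version else 'NOT INSTALLED'}")
    if installed_version("onnxruntime-training") != PINNED_ORT_TRAINING:
        print(f"Warning: this article targets onnxruntime-training {PINNED_ORT_TRAINING}")
```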

3. List of loss functions supported by ONNXRuntime

As specified in the ONNXRuntime training API [2], only the following loss functions are available as predefined losses in ONNXRuntime:

class LossType(Enum):
    """Loss type to be added to the training model.
    To be used with the `loss` parameter of `generate_artifacts` function.
    """
    MSELoss = 1
    CrossEntropyLoss = 2
    BCEWithLogitsLoss = 3
    L1Loss = 4
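These predefined losses compute standard formulas. With CrossEntropyLoss, which is used in the MobileNet V2 example below, the model is expected to output raw, unnormalized logits (no softmax layer), and the labels are integer class indices, as with torch.nn.CrossEntropyLoss. The pure-Python sketch below shows what is computed, assuming mean reduction over the batch (the PyTorch default). It is a reference for intuition only, not ONNXRuntime code.

```python
"""Pure-Python reference of a mean-reduced cross-entropy loss (sketch only)."""
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy of raw logits [batch][classes] vs integer labels [batch]."""
    total = 0.0
    for row, label in zip(logits, labels):
        # Stable log-sum-exp: subtract the max logit before exponentiating.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # -log(softmax(row)[label]) == log_sum_exp - row[label]
        total += log_sum_exp - row[label]
    return total / len(labels)


if __name__ == "__main__":
    # Two-class example, matching the binary MobileNet V2 classifier below.
    logits = [[2.0, -1.0], [0.5, 0.5]]
    labels = [0, 1]
    print(f"{cross_entropy(logits, labels):.4f}")
```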

4. Generating training artifacts

To demonstrate the process of training artifacts generation, two different use cases are described:

  1. Predefined loss function: uses a simple MobileNet V2 model for image classification with a predefined loss.
  2. Custom loss function: uses an SSD MobileNet V2 model for object detection, whose loss function is more complex to define.

4.1. Generating training artifacts with predefined loss

In this section, we import a MobileNet V2 model for image classification from the torchvision library and generate its training artifacts with one of the ONNXRuntime predefined loss functions, CrossEntropyLoss.

4.1.1. Exporting model to ONNX format

We start by defining the model and exporting it to ONNX format. The original model is pretrained on the ImageNet-1K dataset (IMAGENET1K_V2 weights), but for this example we replace its classifier with a two-class classifier for binary classification.

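import os

import torch
import torchvision

# Load MobileNet V2 with weights pretrained on ImageNet-1K.
model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
num_classes = 2

# Replace the last classifier layer (1280 input features) with a binary classifier.
model.classifier[1] = torch.nn.Linear(1280, num_classes)

# Export the model to ONNX; the output directory must exist before the export.
model_name = "mobilenet_v2"
os.makedirs("training_artifacts", exist_ok=True)
torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                  f"training_artifacts/{model_name}.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})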

4.1.2. Generating training artifacts for MobileNet V2

Now that we have exported our mobilenet_v2.onnx forward graph model, we can generate the training artifacts.


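import onnx
from onnxruntime.training import artifacts

# Load the onnx model exported in the previous step.
model_name = "mobilenet_v2"
onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")

# Parameters to retrain: the new classifier and the last convolutional layer.
# onnx::Conv_691 and onnx::Conv_692 are names generated by the exporter; they
# may differ with another PyTorch version, so check them in your exported graph.
requires_grad = ["classifier.1.weight", "classifier.1.bias", "onnx::Conv_691", "onnx::Conv_692"]
# All other initializers stay frozen.
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]

# Generate the training artifacts.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts")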
Important
For this example, we assume that the classifier must be retrained along with the last convolutional layer to obtain better accuracy after fine-tuning the model on the target. You are free to explore fine-tuning with other layers unfrozen, but keep in mind the RAM constraints of the device.
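Given the RAM budget of the device, it can help to validate the requires_grad list and roughly size the training state before generating the artifacts. The plain-Python sketch below splits parameter names into trainable and frozen lists, failing early if a requested name (such as an exporter-generated onnx::Conv_* name) does not exist, and estimates the float32 footprint. The helper names are hypothetical, and so is the cost model: one copy per frozen parameter, four per trainable one (weight, gradient and the two AdamW moment estimates). Activations are not counted, so the real footprint is larger.

```python
"""Split parameters into trainable/frozen lists and roughly size the training state.

Sketch with hypothetical helper names; parameter counts are given by hand here.
"""

BYTES_PER_FLOAT32 = 4
# Rough assumption: a trainable float32 parameter is stored 4 times
# (weight, gradient, AdamW first moment, AdamW second moment); a frozen one once.
TRAINABLE_COPIES = 4
FROZEN_COPIES = 1


def split_params(param_sizes, requires_grad):
    """Return (trainable, frozen) name lists, failing early on unknown names."""
    missing = [name for name in requires_grad if name not in param_sizes]
    if missing:
        raise KeyError(f"requires_grad names not found in the graph: {missing}")
    wanted = set(requires_grad)
    trainable = [name for name in param_sizes if name in wanted]
    frozen = [name for name in param_sizes if name not in wanted]
    return trainable, frozen


def estimate_state_bytes(param_sizes, trainable):
    """Rough float32 footprint of parameters and optimizer state (no activations)."""
    trainable_set = set(trainable)
    total = 0
    for name, count in param_sizes.items():
        copies = TRAINABLE_COPIES if name in trainable_set else FROZEN_COPIES
        total += copies * count * BYTES_PER_FLOAT32
    return total


if __name__ == "__main__":
    # Hypothetical counts, assuming onnx::Conv_691/692 are the weight and bias of
    # MobileNet V2's last 1x1 convolution (320 to 1280 channels, BatchNorm folded).
    sizes = {
        "onnx::Conv_691": 1280 * 320,
        "onnx::Conv_692": 1280,
        "classifier.1.weight": 2 * 1280,
        "classifier.1.bias": 2,
    }
    trainable, frozen = split_params(sizes, list(sizes))
    mib = estimate_state_bytes(sizes, trainable) / 2**20
    print(f"trainable={len(trainable)} frozen={len(frozen)} ~{mib:.1f} MiB")
```

With the real graph, the size table can be built from the exported model, for example with `{p.name: math.prod(p.dims) for p in onnx_model.graph.initializer}`, and passed together with the same requires_grad list as in the script above.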

Four files are generated. The training graph computes the gradients, the optimizer graph updates the model parameters based on the computed gradients, and the evaluation graph computes the validation loss on the evaluation dataset. The checkpoint file contains the essential training state, such as the trainable and non-trainable parameters.
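Before deploying, a quick completeness check of the artifact directory can avoid surprises on the target. The sketch below assumes the default file names that onnxruntime.training.artifacts.generate_artifacts writes when no prefix is given (training_model.onnx, eval_model.onnx, optimizer_model.onnx and checkpoint); the helper name is hypothetical, and the names should be checked against your own output directory.

```python
"""Check that an artifact directory contains the four expected files (sketch)."""
import sys
from pathlib import Path

# Assumed default names written by generate_artifacts (no prefix).
EXPECTED_ARTIFACTS = [
    "training_model.onnx",   # training graph: computes the gradients
    "eval_model.onnx",       # evaluation graph: computes the validation loss
    "optimizer_model.onnx",  # optimizer graph: AdamW parameter update
    "checkpoint",            # trainable and non-trainable parameters
]


def missing_artifacts(directory):
    """Return the expected artifact names that are absent from directory."""
    root = Path(directory)
    return [name for name in EXPECTED_ARTIFACTS if not (root / name).exists()]


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "training_artifacts"
    missing = missing_artifacts(target)
    print("All artifacts present" if not missing else f"Missing: {missing}")
```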

Now that the training artifacts have been generated, we can deploy them to the device to fine-tune the neural network model on the device.

4.2. Generating training artifacts with custom loss

For models with a more complex loss function, the approach is to wrap the PyTorch model so that the exported ONNX model computes the loss in its forward graph. This makes it possible to generate training graphs for a model whose loss function is not supported by ONNXRuntime.
To illustrate this, we generate the training artifacts for the SSD MobileNet V2 model used by the on-device learning object detection applications.
For simplicity, we use the SSD MobileNet V2 [3], in which the loss function has been added to the forward graph. The first step is to clone the repository to your local machine.

git clone https://github.com/stm32-hotspot/pytorch-ssd
cd pytorch-ssd


Information
For this example, we use the weights of the MobileNet V2 backbone trained on the ImageNet dataset. You can choose the weights depending on your use case. The repository provides a train_ssd.py script to perform an initial training of the model on your local machine before exporting it to ONNX and generating its training artifacts for on-device learning.

4.2.1. Exporting the model with loss function to ONNX format

To create the forward graph with the loss function integrated, run the following script. It generates an ONNX forward model whose graph produces both the regular outputs and the loss output.

# Copyright (c) 2025 STMicroelectronics. All rights reserved.
#

import argparse
from onnxruntime.training import artifacts
import onnx
import torch
import os

from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite
from vision.nn.multibox_loss import MultiboxLoss
from vision.ssd.config import mobilenetv1_ssd_config
from vision.utils.box_utils import generate_ssd_priors


def main():
    parser = argparse.ArgumentParser(
                        description='Script to help with the export of SSDMobileNetV2 model with loss function.')
    parser.add_argument('--weights_path', required=True,
                        help='Path of the weights file to use for the base net.')
    parser.add_argument('--nb_classes',
                        help='Number of classes to be predicted, excluding BACKGROUND. Default = 1.', type=int, default=1)
    parser.add_argument('--img_size', default=300, type=int,
                        help='Desired image size for training. Default = 300.')
    parser.add_argument('--output_model_name',
                        help='Name of the exported model. Default = ssd_model_with_outputs_and_loss.', default="ssd_model_with_outputs_and_loss")
    parser.add_argument('--onnx_opset',
                        help='Opset version to use for model export. Default = 17.', type=int, default=17)
    parser.add_argument('--iou_threshold',
                        help='IoU threshold for the loss function. Default = 0.5.', type=float, default=0.5)
    args = parser.parse_args()

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")

    target_filenames = ['mb2-ssd-lite-mp-0_686.pth', 'mb2-imagenet-71_8.pth']
    weights_filename = os.path.basename(args.weights_path)
    if weights_filename not in target_filenames:
        print("Weights filename should be either mb2-ssd-lite-mp-0_686.pth or mb2-imagenet-71_8.pth")
        return False

    img_size = args.img_size

    # Load a model (with pretrained weights)
    ssd_model = create_mobilenetv2_ssd_lite(num_classes=args.nb_classes + 1, onnx_compatible=True)

    class Mbv2SSDnWithLoss(torch.nn.Module):
        def __init__(self, ssd_model, train_mode=True):
            super(Mbv2SSDnWithLoss, self).__init__()
            self.model = ssd_model
            self.model.train(train_mode)
            self.criterion = MultiboxLoss(generate_ssd_priors(mobilenetv1_ssd_config.specs, img_size), iou_threshold=args.iou_threshold, neg_pos_ratio=3,
                                          center_variance=0.1, size_variance=0.2, device=device, export_to_onnx=True)

            # Initialize the base net with pretrained weights
            if weights_filename == target_filenames[0]:
                self.model.init_from_pretrained_ssd(args.weights_path)
            else:
                self.model.init_from_base_net(args.weights_path)

        def forward(self, batch, labels, in_boxes):
            confs, out_boxes = self.model(batch)
            regression_loss, classification_loss = self.criterion(
                confs, out_boxes, labels, in_boxes)
            loss = regression_loss + classification_loss
            # Loss should be the first output in order to correctly generate the artifacts
            return loss, confs, out_boxes

    ssd_model_with_loss = Mbv2SSDnWithLoss(ssd_model, train_mode=True)

    # Generate anchors
    anchors = generate_ssd_priors(mobilenetv1_ssd_config.specs, img_size)
    nb_anchors = anchors.size()[0]

    # The number of anchors depends on the model architecture and the image size
    # Set training option to torch.onnx.TrainingMode.TRAINING to export the model in training friendly mode
    torch.onnx.export(ssd_model_with_loss, (torch.randn(3, 3, img_size, img_size), torch.zeros((3, nb_anchors)), torch.rand(3, nb_anchors, 4)),
                      f"{args.output_model_name}.onnx",
                      input_names=["images", "labels", "in_boxes"], output_names=["loss", "confs", "out_boxes"],
                      dynamic_axes={"images": {0: "batch"},
                                    "labels": {0: "batch", 1: "priors"},
                                    "in_boxes": {0: "batch", 1: "priors", 2: "coordinates"},
                                    "out_boxes": {0: "batch", 1: "priors", 2: "coordinates"},
                                    "confs": {0: "batch", 1: "priors", 2: "confidences"},
                                    "loss": {0: "batch"}}, training=torch.onnx.TrainingMode.TRAINING,
                                    do_constant_folding=False, export_params=True, opset_version=args.onnx_opset)

if __name__ == '__main__':
    main()
Information
The SSD model anchors must be generated first, because their number determines the shapes of the model inputs and outputs. To make the exported graph trainable with ONNXRuntime, the training argument of torch.onnx.export must be set to torch.onnx.TrainingMode.TRAINING. With onnxruntime-training 1.19.2, the opset version must be 17 or higher.

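Assuming the script is saved as export_ssd_to_onnx.py at the root of the cloned repository (so that the vision package can be imported), run it with the following command:

python3 export_ssd_to_onnx.py --weights_path models/mb2-imagenet-71_8.pth --nb_classes 2 --img_size 300 --output_model_name ssd_mobilenetv2_with_loss --onnx_opset 17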

This creates a new ssd_mobilenetv2_with_loss.onnx model file.

Warning
The BACKGROUND class is included by default: set --nb_classes only to the number of classes you want to train your model for.
Information
The --weights_path argument expects the path to mb2-ssd-lite-mp-0_686.pth or mb2-imagenet-71_8.pth. The first file contains the weights of the whole network pretrained on the PASCAL VOC dataset; the second contains the weights of the base net (MobileNetV2) pretrained on the ImageNet dataset.
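Since the anchor count fixes the second dimension of the labels, in_boxes, confs and out_boxes tensors, it can help to compute it by hand. The plain-Python sketch below reproduces the counting logic of generate_ssd_priors, assuming the default specs of vision/ssd/config/mobilenetv1_ssd_config.py from the upstream pytorch-ssd project: a 300x300 input, feature maps of 19, 10, 5, 3, 2 and 1 cells, and aspect ratios 2 and 3. That gives 6 priors per cell over 19² + 10² + 5² + 3² + 2² + 1² = 500 cells, that is 3000 priors, and the last dimension of confs is nb_classes + 1 because of the BACKGROUND class. The helper names are hypothetical; check the config of the cloned repository, which may differ.

```python
"""Count SSD priors and derive the tensor shapes used for the ONNX export (sketch)."""

# (feature map size, aspect ratios) for each prediction layer (assumed defaults)
DEFAULT_SPECS = (
    (19, (2, 3)),
    (10, (2, 3)),
    (5, (2, 3)),
    (3, (2, 3)),
    (2, (2, 3)),
    (1, (2, 3)),
)


def priors_per_cell(aspect_ratios):
    """Two square boxes plus two boxes (r:1 and 1:r) per aspect ratio r."""
    return 2 + 2 * len(aspect_ratios)


def count_priors(specs):
    """Total number of priors over all feature map cells."""
    return sum(size * size * priors_per_cell(ratios) for size, ratios in specs)


def expected_shapes(batch, nb_classes, img_size=300, specs=DEFAULT_SPECS):
    """Shapes of the export inputs/outputs; confs has one extra BACKGROUND class."""
    nb_priors = count_priors(specs)
    return {
        "images": (batch, 3, img_size, img_size),
        "labels": (batch, nb_priors),
        "in_boxes": (batch, nb_priors, 4),
        "confs": (batch, nb_priors, nb_classes + 1),
        "out_boxes": (batch, nb_priors, 4),
    }


if __name__ == "__main__":
    for name, shape in expected_shapes(batch=3, nb_classes=2).items():
        print(name, shape)
```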

4.2.2. Generating training artifacts for SSD MobileNet V2

After exporting the model containing the loss outputs in the graph to the ONNX format, we can generate the training artifacts for this model by running the following script:

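import argparse

import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts

# List of initializer names trained with --freeze_optimized (config.py of the repository).
from config import optimized_grad

parser = argparse.ArgumentParser(
    description='Generate the training artifacts of an ONNX model that embeds its loss.')
parser.add_argument('onnx_model_path', help='Path of the ONNX model exported with its loss.')
parser.add_argument('artifacts_dir_path', help='Directory where the training artifacts are written.')
freeze_group = parser.add_mutually_exclusive_group()
freeze_group.add_argument('--freeze_net', action='store_true',
                          help='Freeze all layers but the heads.')
freeze_group.add_argument('--freeze_basenet', action='store_true',
                          help='Freeze only the base net (MobileNetV2 in this case).')
freeze_group.add_argument('--freeze_optimized', action='store_true',
                          help='Train only the layers listed in config.optimized_grad.')
args = parser.parse_args()

# Load the onnx model.
onnx_model = onnx.load(args.onnx_model_path)

# Put BatchNormalization nodes back in inference mode: in training mode they also
# output running_mean and running_var, which causes shape inference errors.
for node in onnx_model.graph.node:
    if node.op_type == "BatchNormalization":
        for attribute in node.attribute:
            if attribute.name == 'training_mode':
                if attribute.i == 1:
                    # Outputs are [Y, running_mean, running_var]: keep only Y.
                    node.output.remove(node.output[1])
                    node.output.remove(node.output[1])
                attribute.i = 0

if args.freeze_optimized:
    # Train only the layers of the optimized scheme (config.py)
    requires_grad = [param.name for param in onnx_model.graph.initializer
                     if param.name in optimized_grad]
elif args.freeze_net:
    # Freeze all layers but the heads
    requires_grad = [param.name for param in onnx_model.graph.initializer
                     if "extras" not in param.name
                     and "base_net" not in param.name
                     and "running_mean" not in param.name
                     and "running_var" not in param.name]
elif args.freeze_basenet:
    # Freeze only the base net (MobileNetV2 in this case)
    requires_grad = [param.name for param in onnx_model.graph.initializer
                     if "base_net" not in param.name
                     and "running_mean" not in param.name
                     and "running_var" not in param.name]
else:
    # Train all layers except the BatchNormalization running statistics
    requires_grad = [param.name for param in onnx_model.graph.initializer
                     if "running_mean" not in param.name
                     and "running_var" not in param.name]

frozen_params = [param.name for param in onnx_model.graph.initializer
                 if param.name not in requires_grad]

# Generate the training artifacts. loss=None because the loss is already
# the first output of the forward graph.
artifacts.generate_artifacts(onnx_model,
                             loss=None,
                             optimizer=onnxblock.optim.AdamW(weight_decay=5e-4, eps=1e-8),
                             requires_grad=requires_grad,
                             frozen_params=frozen_params,
                             artifact_directory=args.artifacts_dir_path)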
Warning
As done in the generate_artifacts.py script, remove the running_mean and running_var outputs of the BatchNormalization nodes from the ONNX graph (and set their training_mode attribute to 0) to avoid shape inference errors.
Important
Keep the loss argument of the artifacts.generate_artifacts function set to None, as the loss function is already embedded in the forward graph.

The generate_artifacts.py script is run with the following commands:

mkdir training_artifacts
python3 generate_artifacts.py ./ssd_mobilenetv2_with_loss.onnx ./training_artifacts --freeze_optimized
Information
This script provides three options to choose the layers for which the gradients are computed:
  • --freeze_net: freezes all layers except the heads, so gradients are computed only for the heads.
  • --freeze_basenet: freezes only the base network (MobileNet V2 in this case), so gradients are computed for all other layers.
  • --freeze_optimized: computes gradients only for the subset of layers listed in config.py, for optimized training.
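The effect of each option can be previewed without loading a model. The following plain-Python sketch mirrors the requires_grad selection rules of the script above so they can be checked on a list of names; the mode labels ("net", "basenet", "optimized", "all") and the example initializer names are hypothetical.

```python
"""Preview which initializers each freeze option of generate_artifacts.py trains (sketch)."""

BN_STATS = ("running_mean", "running_var")


def select_trainable(names, mode="all", optimized_grad=()):
    """Return the names that require gradients for the given mode."""
    if mode == "optimized":  # --freeze_optimized: only names listed in config.py
        return [n for n in names if n in optimized_grad]
    if mode == "net":        # --freeze_net: only the heads are trained
        excluded = ("extras", "base_net") + BN_STATS
    elif mode == "basenet":  # --freeze_basenet: everything but base_net is trained
        excluded = ("base_net",) + BN_STATS
    elif mode == "all":      # no option: everything but BatchNorm statistics
        excluded = BN_STATS
    else:
        raise ValueError(f"unknown mode: {mode}")
    return [n for n in names if not any(key in n for key in excluded)]


if __name__ == "__main__":
    # Hypothetical initializer names, for illustration only.
    names = [
        "base_net.0.0.weight",
        "base_net.0.1.running_mean",
        "extras.0.0.weight",
        "classification_headers.0.weight",
        "regression_headers.0.weight",
    ]
    for mode in ("net", "basenet", "all"):
        print(mode, select_trainable(names, mode))
```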

Four files are generated: three ONNX graphs and a checkpoint file. The training graph computes the gradients, the optimizer graph updates the model parameters based on the computed gradients, and the evaluation graph computes the validation loss on the evaluation dataset. The checkpoint file contains the essential training state, such as the trainable and non-trainable parameters.

Now that the training artifacts have been generated, we can deploy them to the device to fine-tune the neural network model on the device.

5. References