1. What is On-device learning?
On-device learning refers to training a machine learning model directly on the edge device where the data is collected, without relying on cloud services or external servers. This contrasts with the usual approach of training a model on a server or in the cloud. This approach is becoming increasingly popular and has several advantages:
- Data confidentiality and privacy: sensitive data never has to be transmitted over the network, which reduces the risk of interception.
- Personalization of neural network (NN) models: a model is fine-tuned on the device using locally collected data, producing a new model tailored specifically to the end user. This personalized model is then used for inference, which improves the user experience and allows the NN model to continuously learn and adapt to new data collected on the device.
- Offline capability: Training can be performed without an internet connection, making it suitable for remote or offline scenarios.
On-device learning is a powerful approach that leverages the capabilities of modern mobile and edge devices to provide personalized, secure, and efficient machine learning solutions. It can be used in scenarios like model personalization, knowledge distillation, teacher-student machine learning and even federated learning.
2. On-device learning using ONNXRuntime
The on-device learning feature extends the inference capabilities of ONNXRuntime [1] with support for training directly on edge devices. It aims to simplify taking an inference model and fine-tuning it locally using on-device data. The figure below shows the high-level workflow of on-device learning using the ONNXRuntime framework.
With the on-device learning feature integrated within ONNXRuntime [1], application developers can perform inference and training using the same binaries. After a training session, the runtime generates optimized, inference-ready models that can deliver a more personalized experience on the device, depending on the use case. The workflow is based on two main phases:
- The offline phase: training graphs that include backward propagation are prepared on a host computer from an ONNX inference model.
- The training phase: the training loop runs on the device, and the weights are updated locally.
The following figure summarizes the on-device learning workflow using the ONNXRuntime training framework:
2.1. The offline phase
This phase prepares the prerequisite files for the actual training. It is typically performed offline on a server or on the user's computer, and the files it generates are then used during on-device training. These files include:
- The training graph: an ONNX model containing the model's forward computation, the loss function, and the gradient computations.
- The evaluation graph: an ONNX model with the same structure as the forward (inference) graph, with the loss function appended at the end.
- The optimizer graph: an ONNX model that contains operations such as gradient normalization and parameter updates.
- The checkpoint file: a file containing the essential training state, such as the trainable and non-trainable parameters. These parameters are referenced by all the model graphs listed above rather than duplicated in each of them, which reduces the model size.
These training artifacts can be generated with the ONNXRuntime training Python tools; refer to the How to generate training artifacts article for hands-on experience.
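When the artifacts are generated, the parameters of the inference model are split into those that will be trained on the device and those that stay frozen; `onnxruntime.training.artifacts.generate_artifacts` takes these as its `requires_grad` and `frozen_params` arguments. Training only the last layers is a common choice on constrained devices, since gradients and optimizer state are then kept for only a few parameters. The stdlib-only helper below sketches that split. `split_parameters` and the parameter names are hypothetical; in practice the names would be read from the model's initializers (for example `[init.name for init in onnx.load(path).graph.initializer]`).

```python
def split_parameters(param_names, trainable_prefixes):
    """Partition parameter names into (requires_grad, frozen_params).

    A parameter is trainable if its name starts with one of the given
    prefixes; every other parameter stays frozen.
    """
    prefixes = tuple(trainable_prefixes)
    requires_grad = [n for n in param_names if n.startswith(prefixes)]
    trainable = set(requires_grad)
    frozen_params = [n for n in param_names if n not in trainable]
    return requires_grad, frozen_params


if __name__ == "__main__":
    # Hypothetical initializer names, as they might appear in an exported ONNX model.
    names = ["conv1.weight", "conv1.bias", "fc.weight", "fc.bias"]
    # Fine-tune only the final fully connected layer ("fc.*").
    trainable, frozen = split_parameters(names, ["fc."])
    print("requires_grad:", trainable)
    print("frozen_params:", frozen)
```

The two lists, together with a loss type (for example `artifacts.LossType.CrossEntropyLoss`), an optimizer type (for example `artifacts.OptimType.AdamW`) and an `artifact_directory`, would then be passed to `generate_artifacts` on the host computer.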
2.2. The training phase
The first step consists of loading the NN model together with the training artifacts, that is, the training, evaluation and optimizer graphs and the checkpoint. For this purpose, the onnxruntime-training Python module provides the classes Module, to load the training and evaluation graphs, Optimizer, to load the optimizer graph, and CheckpointState, to load the checkpoint file containing the model parameters (for example, previously pre-trained weights). The training loop then successively calls the training graph to compute the gradients, the optimizer graph to update the model parameters based on the computed gradients, and finally the evaluation graph on the evaluation dataset to compute the validation loss. The number of epochs should be high enough for the loss function to converge.
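The loop described above can be sketched in plain Python, with no ONNXRuntime dependency, to show which computation each artifact contributes. `TinyLinearModel`, `sgd_step` and `train` are hypothetical names, and the model (a one-variable linear regression with an MSE loss) and the plain SGD optimizer are stand-ins chosen for brevity. With onnxruntime-training, the corresponding steps would roughly be calling the `Module` object after `module.train()`, `optimizer.step()`, `module.lazy_reset_grad()`, and calling the module again after `module.eval()`.

```python
import random


class TinyLinearModel:
    """Stand-in for the training and evaluation graphs: y = w * x + b, MSE loss."""

    def __init__(self, w=0.0, b=0.0):
        # Plays the role of the checkpoint state (trainable parameters).
        self.params = {"w": w, "b": b}
        self.grads = {"w": 0.0, "b": 0.0}

    def train_step(self, xs, ys):
        """Forward pass, loss and gradients (what the training graph computes)."""
        n = len(xs)
        w, b = self.params["w"], self.params["b"]
        loss = 0.0
        for x, y in zip(xs, ys):
            err = w * x + b - y
            loss += err * err / n
            # Gradients of the MSE loss, accumulated until explicitly reset.
            self.grads["w"] += 2.0 * err * x / n
            self.grads["b"] += 2.0 * err / n
        return loss

    def eval_step(self, xs, ys):
        """Forward pass and loss only (what the evaluation graph computes)."""
        w, b = self.params["w"], self.params["b"]
        return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

    def reset_grad(self):
        for k in self.grads:
            self.grads[k] = 0.0


def sgd_step(model, lr):
    """Parameter update (what the optimizer graph computes), here plain SGD."""
    for k in model.params:
        model.params[k] -= lr * model.grads[k]


def train(model, train_set, val_set, epochs, batch_size=4, lr=0.05, seed=0):
    """Run the train / update / evaluate loop; return the validation loss per epoch."""
    rng = random.Random(seed)
    data = list(train_set)
    history = []
    for _ in range(epochs):
        rng.shuffle(data)
        for i in range(0, len(data), batch_size):
            xs, ys = zip(*data[i:i + batch_size])
            model.train_step(xs, ys)  # 1. compute gradients
            sgd_step(model, lr)       # 2. update parameters
            model.reset_grad()        #    clear accumulated gradients
        vx, vy = zip(*val_set)
        history.append(model.eval_step(vx, vy))  # 3. validation loss
    return history


if __name__ == "__main__":
    # Hypothetical on-device data: y = 3x + 1 plus a little noise.
    rng = random.Random(42)
    samples = [(x / 10, 3 * (x / 10) + 1 + rng.gauss(0, 0.05)) for x in range(-20, 20)]
    model = TinyLinearModel()
    losses = train(model, samples[::2], samples[1::2], epochs=50)
    print(f"val loss: first epoch {losses[0]:.4f}, last epoch {losses[-1]:.4f}")
    print("learned params:", model.params)
```

Tracking the validation loss per epoch, as `train` does, is one simple way to check that the chosen number of epochs is enough for the loss to converge.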
3. On-device Learning limitations
While on-device learning offers several advantages, such as improved privacy, reduced latency, and lower bandwidth usage, it also comes with certain limitations. Here are some key limitations of on-device learning:
- Limited computational resources: edge devices typically have limited processing power and memory, which constrains the complexity and size of the models that can be trained, as well as the batch size used for training.
- Catastrophic forgetting: also known as catastrophic interference, this is a phenomenon in machine learning where a model forgets previously learned information when it learns new information. It can be mitigated with techniques such as knowledge distillation.
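Knowledge distillation adds a loss term that keeps the outputs of the model being trained close to those of a reference model. To limit forgetting, one option is to use a frozen copy of the model taken before on-device training as that reference, so that the updated model learns the new labels without drifting too far from its previous predictions. The plain-Python sketch below combines a cross-entropy term on the new label with the temperature-scaled KL-divergence distillation loss. The function names and the `alpha` and `temperature` values are hypothetical hyperparameters, not ONNXRuntime settings; in an ONNXRuntime workflow, such a loss would have to be part of the training graph produced in the offline phase.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_total for z in scaled]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """T^2 * KL(p_teacher || p_student) between temperature-softened outputs."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return temperature ** 2 * kl


def cross_entropy(logits, label):
    """Cross-entropy for one sample with an integer class label."""
    return -log_softmax(logits)[label]


def combined_loss(new_logits, old_logits, label, alpha=0.5, temperature=2.0):
    """Hypothetical on-device loss: fit the new label (cross-entropy) while staying
    close to the frozen pre-update model's outputs (distillation term)."""
    return ((1.0 - alpha) * cross_entropy(new_logits, label)
            + alpha * distillation_loss(new_logits, old_logits, temperature))


if __name__ == "__main__":
    old = [2.0, 0.5, -1.0]  # logits of the frozen copy made before on-device training
    new = [0.5, 2.0, -1.0]  # logits of the model being fine-tuned
    print("distillation term:", distillation_loss(new, old))
    print("combined loss for label 1:", combined_loss(new, old, label=1))
```

A larger `alpha` favors retaining previous behavior over fitting the new data; the right balance depends on the use case.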
4. Teacher-student use case
Teacher-student machine learning with automatic labeling is an interesting use case of on-device learning. It is an advanced technique in which a large, well-trained model (the teacher) is used to generate labels for a dataset that lacks annotations. The teacher model predicts labels for the unlabeled data, which are then used to train a smaller, simpler model (the student). This process allows the student model to learn from the teacher, effectively transferring the teacher's expertise to the student while keeping the student's inference time within the device's constraints.
This method is particularly useful for creating efficient models for deployment on resource-constrained devices, such as STM32MP2 series boards, while also reducing the need for, and the effort involved in, manual data labeling.
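The automatic labeling step can be sketched as confidence-thresholded pseudo-labeling: the teacher labels each unlabeled sample with its most probable class, low-confidence predictions are discarded, and the student is trained on the remaining samples. In this plain-Python sketch, `toy_teacher` is a hypothetical stand-in for a large teacher network, a nearest-centroid classifier stands in for the student network, and the 0.9 threshold is an illustrative value.

```python
import math
from collections import defaultdict


def toy_teacher(point):
    """Hypothetical stand-in for a large teacher model: two classes split by x0 + x1 = 0."""
    p1 = 1.0 / (1.0 + math.exp(-5.0 * (point[0] + point[1])))
    return [1.0 - p1, p1]


def pseudo_label(teacher, samples, threshold=0.9):
    """Label samples with the teacher's most probable class, keeping only
    predictions whose confidence reaches the threshold."""
    labeled = []
    for x in samples:
        probs = teacher(x)
        label = max(range(len(probs)), key=probs.__getitem__)
        if probs[label] >= threshold:
            labeled.append((x, label))
    return labeled


def train_nearest_centroid(labeled):
    """Tiny 'student': one centroid per class, computed from pseudo-labeled 2-D points."""
    sums = defaultdict(lambda: [0.0, 0.0, 0])
    for (x0, x1), label in labeled:
        s = sums[label]
        s[0] += x0
        s[1] += x1
        s[2] += 1
    return {label: (s[0] / s[2], s[1] / s[2]) for label, s in sums.items()}


def predict(centroids, point):
    """Return the class whose centroid is closest (squared Euclidean distance)."""
    return min(centroids, key=lambda c: (point[0] - centroids[c][0]) ** 2
               + (point[1] - centroids[c][1]) ** 2)


if __name__ == "__main__":
    unlabeled = [(-1.0, -1.0), (-0.8, -1.2), (1.0, 1.0), (1.2, 0.9), (0.01, 0.0)]
    labeled = pseudo_label(toy_teacher, unlabeled)
    student = train_nearest_centroid(labeled)
    print("kept samples:", labeled)
    print("student prediction for (0.9, 1.1):", predict(student, (0.9, 1.1)))
```

Discarding low-confidence predictions trades fewer training samples for fewer wrong labels, which matters because the student has no ground truth to correct teacher mistakes.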
5. References