Integrating Machine Learning Models in Flutter with TensorFlow Lite

Flutter, Google’s UI toolkit for building natively compiled applications for mobile, web, and desktop from a single codebase, has gained significant traction in the developer community. One exciting capability is the seamless integration of machine learning (ML) models. TensorFlow Lite (TFLite) is a set of tools that enables on-device machine learning inference, letting developers run ML models directly on mobile devices. This blog post walks you through integrating a machine learning model into a Flutter app using TensorFlow Lite.

What is TensorFlow Lite?

TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables on-device machine learning inference with low latency, while maintaining a small binary size. TFLite supports various types of ML models, including image classification, object detection, and natural language processing.

Why Use TensorFlow Lite in Flutter?

  • On-Device Inference: Runs ML models directly on the device without needing a network connection.
  • Low Latency: Optimized for performance to provide real-time or near real-time results.
  • Privacy: Keeps user data local and avoids sending it to external servers.
  • Efficiency: Reduces the computational burden on backend servers, leading to cost savings.

Steps to Integrate Machine Learning Models in Flutter with TensorFlow Lite

Follow these steps to integrate an ML model into a Flutter application using TensorFlow Lite:

Step 1: Prepare Your TensorFlow Lite Model

First, you need a pre-trained TensorFlow Lite model (.tflite file). You can either create your own using TensorFlow or download a pre-trained model. For demonstration purposes, let’s use a simple image classification model.
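If you are starting from your own model, a minimal conversion sketch looks like the following. This assumes TensorFlow is installed; the tiny untrained model here is only a placeholder for your real trained one.

```python
# Sketch: convert a Keras model to a .tflite file with TFLiteConverter.
# The one-layer model below is a stand-in; substitute your trained model.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_bytes = converter.convert()

# Write the serialized FlatBuffer to disk for use as a Flutter asset.
with open("model.tflite", "wb") as f:
    f.write(tflite_bytes)
```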

Step 2: Add the TensorFlow Lite Flutter Plugin

Add the tflite plugin to your pubspec.yaml file:

dependencies:
  flutter:
    sdk: flutter
  tflite: ^1.1.2 # Use the latest version

Run flutter pub get to install the plugin.

Step 3: Import the TensorFlow Lite Model

Create an assets folder in your Flutter project and copy your .tflite model and labels file (a .txt file with one class label per line) into it. Then, register these assets in your pubspec.yaml file:

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt
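The labels file is plain text with one class per line, in the same order as the model’s output indices. Many exported models (for example, from Teachable Machine) prefix each label with its index, which is why the code in Step 4 strips the first two characters before displaying a label. A hypothetical labels.txt for a three-class classifier:

```
0 cat
1 dog
2 bird
```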

Step 4: Implement the Flutter Application

Here’s a basic example of how to implement image classification in Flutter using TensorFlow Lite:

import 'dart:io';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'TFLite Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  File? _image;
  List? _outputs;
  bool _loading = false;

  @override
  void initState() {
    super.initState();
    _loading = true;
    loadModel().then((value) {
      setState(() {
        _loading = false;
      });
    });
  }

  Future<void> loadModel() async {
    try {
      await Tflite.loadModel(
          model: "assets/model.tflite",
          labels: "assets/labels.txt",
          numThreads: 1, // defaults to 1
          isAsset:
              true, // defaults to true, set to false to load resources outside assets
          useGpuDelegate:
              false // defaults to false, set to true to use GPU delegate
          );
    } catch (e) {
      print('Failed to load model. $e');
    }
  }

  Future<void> classifyImage(File image) async {
    var output = await Tflite.runModelOnImage(
        path: image.path,
        imageMean: 0.0, // defaults to 117.0
        imageStd: 255.0, // defaults to 1.0
        numResults: 2, // defaults to 5
        threshold: 0.2, // defaults to 0.1
        asynch: true // defaults to true
        );
    setState(() {
      _loading = false;
      _outputs = output;
    });
  }

  Future<void> pickImage() async {
    final image = await ImagePicker().pickImage(source: ImageSource.gallery);

    if (image == null) return;

    setState(() {
      _loading = true;
      _image = File(image.path);
    });
    classifyImage(_image!);
  }

  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: const Text('TFLite Image Classification'),
      ),
      body: _loading
          ? const Center(
              child: CircularProgressIndicator(),
            )
          : Container(
              child: Column(
                crossAxisAlignment: CrossAxisAlignment.center,
                mainAxisAlignment: MainAxisAlignment.center,
                children: [
                  _image == null
                      ? const Text('No image selected.')
                      : SizedBox(
                          height: 300,
                          child: Image.file(_image!),
                        ),
                  const SizedBox(
                    height: 20,
                  ),
                  _outputs != null
                      ? Text(
                          // Labels are prefixed with "N " (e.g. "0 cat"),
                          // so strip the first two characters for display.
                          '${_outputs![0]["label"]}'.substring(2),
                          style: const TextStyle(
                              color: Colors.black,
                              fontSize: 20.0,
                              backgroundColor:
                                  Color.fromRGBO(255, 255, 255, 0.8)),
                        )
                      : const Text(''),
                ],
              ),
            ),
      floatingActionButton: FloatingActionButton(
        onPressed: pickImage,
        tooltip: 'Pick Image',
        child: const Icon(Icons.image),
      ),
    );
  }
}

Explanation:

  • Dependencies: Ensure you have the image_picker and tflite dependencies in your pubspec.yaml file.
  • Model Loading: Load the TensorFlow Lite model in the initState method using Tflite.loadModel.
  • Image Classification: Use Tflite.runModelOnImage to classify the selected image. The function takes the image path and optional parameters for mean, standard deviation, and confidence threshold.
  • Result Display: Display the classification result using a Text widget.
  • Image Picker: Allow the user to pick an image from the gallery using the image_picker plugin.
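The list returned by Tflite.runModelOnImage contains one map per result, with index, label, and confidence keys. A small helper (hypothetical, not part of the plugin) that formats the top result together with its confidence might look like this:

```dart
// Hypothetical helper: format the top entry returned by Tflite.runModelOnImage.
// Each entry looks like {"index": 0, "label": "0 cat", "confidence": 0.92}.
String formatTopResult(List<dynamic>? outputs) {
  if (outputs == null || outputs.isEmpty) return 'No result';
  final top = outputs.first as Map<dynamic, dynamic>;
  final label = (top['label'] as String).substring(2); // strip the "N " prefix
  final confidence = (top['confidence'] as double) * 100;
  return '$label (${confidence.toStringAsFixed(1)}%)';
}
```

You could call this in place of the raw substring expression in the build method to show users how confident the model is, not just the label.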

Step 5: Add Image Picker Dependency

The example in Step 4 also relies on the image_picker package to let users select images from their device. Add it to your pubspec.yaml file:

dependencies:
  flutter:
    sdk: flutter
  image_picker: ^0.8.4+4 # Use the latest version

Run flutter pub get again to install it.
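On iOS, image_picker additionally requires a usage description in ios/Runner/Info.plist; without it, the app will be rejected or crash when the gallery is opened. For example:

```xml
<key>NSPhotoLibraryUsageDescription</key>
<string>This app needs photo library access to classify images.</string>
```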

Best Practices

  • Optimize Models: Use model quantization and pruning techniques to reduce the size and improve the performance of your TensorFlow Lite models.
  • Error Handling: Implement proper error handling to catch exceptions that may occur during model loading or inference.
  • Background Processing: Perform ML inference in a background isolate to prevent blocking the main UI thread and ensure a smooth user experience.
  • Device Compatibility: Test your application on a variety of devices to ensure compatibility and performance.
  • User Feedback: Provide visual feedback (e.g., progress indicators) during model loading and inference to keep the user informed.
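As a concrete example of the first point, TensorFlow Lite’s post-training dynamic-range quantization is a one-line addition on the converter. The sketch below reuses a tiny placeholder Keras model and assumes TensorFlow is installed; real size savings depend on your model, but weight-heavy models typically shrink to roughly a quarter of their float32 size.

```python
import tensorflow as tf

def tiny_model():
    # Placeholder stand-in for your real trained model.
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(4,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])

# Baseline float32 conversion.
converter = tf.lite.TFLiteConverter.from_keras_model(tiny_model())
float_bytes = converter.convert()

# Post-training dynamic-range quantization: one extra line on the converter.
quant_converter = tf.lite.TFLiteConverter.from_keras_model(tiny_model())
quant_converter.optimizations = [tf.lite.Optimize.DEFAULT]
quant_bytes = quant_converter.convert()
```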

Conclusion

Integrating machine learning models into Flutter applications with TensorFlow Lite opens up exciting possibilities for creating intelligent and interactive user experiences. By following the steps outlined in this blog post, you can seamlessly integrate TFLite models into your Flutter projects and leverage the power of on-device machine learning. This enables you to create applications that are not only fast and efficient but also respect user privacy by processing data locally.