基于人工智能的Android物体识别软件

Tensorflow神经网络学习

Posted by yourzeromax on May 15, 2018

简介

前几天,参加了Google西南联盟举办的InnoCamp创新实践营,学习了一些关于Tensorflow的知识,也亲手训练了两个人工智能(zhang)出来,一个是手写字数字识别,一个就是物体识别啦~虽然很多东西都是google tensorflow官方网站开源过的,但是这篇文章所介绍的程序是自己利用了OpenCV和Tensorflow进行封装的类,有很强的适用性,可以方便地移植工程中的两个类,这也是本文重点介绍的内容。项目内容可以查看:
test_recogniton·GitHub

工程配置

主要分为Android版本配置,和OpenCV及Tensorflow版本配置

Android版本配置

Android SDK版本必须大于等于21

  • 编译版本(compileSdkVersion) 27
  • 最小版本(minSdkVersion) 24
  • 目标版本(targetSdkVersion)27

  • Gradle插件版本 3.1.2
  • Gradle版本 4.4

其他版本配置

OpenCV 3.20
Tensorflow 1.7
NDK 17.0

神经网络模型库

工程结构

如图所示,主要代码分为三个类,其中MainActivity是Andorid项目工程,OpenCVManager和TFManager分别是针对OpenCV和Tensorflow解耦出来的类,其中TFManager对象采用单例模式并持有OpenCVManager对象(如果要用于大型项目中的话,采用单例模式持有的形式容易造成内存泄漏,可以根据实际情况修改)。

MainActivity.class

public class MainActivity extends Activity implements CvCameraViewListener2 {

    static {
        System.loadLibrary(TFManager.LIBRARYLOCATION);
    }

    private CameraBridgeViewBase mOpenCvCameraView;     //将传递给TFManager对象

    private TFManager mTFManager;

    @Override
    public void onCreate(Bundle savedInstanceState) {
      //  Log.i(TAG, "called onCreate");
        super.onCreate(savedInstanceState);
        getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
        setContentView(R.layout.activity_main);
        if (checkSelfPermission(Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            //申请权限
            requestPermissions(new String[]{Manifest.permission.CAMERA}, 100);
        }

        initViews();
        initWithTensorflow();    //由于要传递OpenCvCameraView对象,所以initWithTensorflow必须在initViews后面执行
    }

    private void initViews(){
        mOpenCvCameraView = (CameraBridgeViewBase) findViewById(R.id.tutorial1_activity_java_surface_view);
        mOpenCvCameraView.setVisibility(SurfaceView.VISIBLE);
        mOpenCvCameraView.setCvCameraViewListener(this);                                    //设置Listener
        mOpenCvCameraView.setMaxFrameSize(640,640);
    }

    private void initWithTensorflow(){
        mTFManager = TFManager.getInstance(this);
        mTFManager.setCameraBridgeViewBase(mOpenCvCameraView);
        mTFManager.setOpenCVCallback(new BaseLoaderCallback(this) {                 //OpenCV回调函数对象的初始化
            @Override
            public void onManagerConnected(int status) {
                switch (status) {
                    case LoaderCallbackInterface.SUCCESS:
                    {
                        //   Log.i(TAG, "OpenCV loaded successfully");
                        mOpenCvCameraView.enableView();
                    } break;
                    default:
                    {
                        super.onManagerConnected(status);
                    } break;
                }
            }
        });
        mTFManager.init();
    }


    @Override
    public void onResume()
    {
        super.onResume();
        mTFManager.resume();
    }

    @Override
    public void onPause()
    {
        super.onPause();
       mTFManager.pause();
    }

    public void onDestroy() {
        super.onDestroy();
        mTFManager.stop();
    }


    public void onCameraViewStarted(int width, int height) {
    }

    public void onCameraViewStopped() {
    }


    //CvCameraViewListener2接口函数
    public Mat onCameraFrame(CvCameraViewFrame inputFrame) {
        Mat img_rgb = inputFrame.rgba();
        Mat img_t = new Mat();
        Mat img_gray = new Mat();
        Mat img_contours;

        Core.transpose(img_rgb,img_t);//转置函数,可以水平的图像变为垂直
        Imgproc.resize(img_t, img_rgb, img_rgb.size(), 0.0D, 0.0D, 0);
        Core.flip(img_rgb, img_rgb,1);  //flipCode>0将mRgbaF水平翻转(沿Y轴翻转)得到mRgba

        if(img_rgb != null) {
            Imgproc.cvtColor(img_rgb, img_gray, Imgproc.COLOR_RGB2GRAY);

            Imgproc.threshold(img_gray, img_gray, 140, 255, Imgproc.THRESH_BINARY_INV);

            //像素加强
            Mat ele1 = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(3, 3));
            Mat ele2 = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(6, 6));
            Imgproc.erode(img_gray, img_gray, ele1);
            Imgproc.dilate(img_gray, img_gray, ele2);

            //找到外界矩形
            img_contours = img_gray.clone();
            List<MatOfPoint> contours = new ArrayList<>();
            Imgproc.findContours(img_contours, contours, new Mat(),
                    Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE);

            for (int contourIdx = 0; contourIdx < contours.size(); contourIdx++) {
                double contourArea = Imgproc.contourArea(contours.get(contourIdx));
                Rect rect = Imgproc.boundingRect(contours.get(contourIdx));
                if (contourArea < 1500 || contourArea > 20000)
                    continue;

                Mat roi = new Mat(img_gray, rect);
                Imgproc.resize(roi, roi, new Size(28, 28));

               mTFManager.detector(img_rgb);
            }
            img_contours.release();
        }


        img_gray.release();
        img_t.release();
        img_t.release();
        return  img_rgb;
    }
}

需要指出来的是,涉及到摄像头权限管理,并且在MainActivity应该实现CvCameraViewListener2接口:

函数 描述
public void onCameraViewStarted(int width, int height); 开始时调用
public void onCameraViewStopped(); 停用时调用
public Mat onCameraFrame(CvCameraViewFrame inputFrame) 返回一个Mat对象

设置回调对象:

mOpenCvCameraView.setCvCameraViewListener(this);

mTFManager.setOpenCVCallback(new BaseLoaderCallback(this) {                 //OpenCV回调函数对象的初始化
            @Override
            public void onManagerConnected(int status) {
                switch (status) {
                    case LoaderCallbackInterface.SUCCESS:
                    {
                        //   Log.i(TAG, "OpenCV loaded successfully");
                        mOpenCvCameraView.enableView();
                    } break;
                    default:
                    {
                        super.onManagerConnected(status);
                    } break;
                }
            }
        });

TFManager.class

在这个类别中,主要封装了相应的模型路径,和一些OpenCVManager会用到的参数,一般会随着主活动的生命周期调用相关的方法,看代码便知,最为关键的方法是对采集图像后,进行Session的提交,如下:

void detector(Mat img_rgb){

        if(tensorFlowInferenceInterface == null){
            return ;}

        Mat img = new Mat();
        Imgproc.resize(img_rgb, img, new Size(inputSize,inputSize));
        Bitmap bitmap = Bitmap.createBitmap(img.width(), img.height(), Bitmap.Config.ARGB_8888);
        Utils.matToBitmap(img, bitmap);

        int[] intValues = new int[img.width() * img.height()];
        byte[] byteValues = new byte[img.width() * img.height() * 3];
        bitmap.getPixels(intValues, 0, bitmap.getWidth(),
                0, 0, bitmap.getWidth(), bitmap.getHeight());

        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            byteValues[i * 3] = (byte) ((val >> 16) & 0xFF);
            byteValues[i * 3 + 1] = (byte) ((val >> 8) & 0xFF);
            byteValues[i * 3 + 2] = (byte) ((val & 0xFF));
        }

        Trace.beginSection("feed");
        tensorFlowInferenceInterface.feed(input_name, byteValues, 1, img.width(), img.height(), 3);
        Trace.endSection();

        Trace.beginSection("run");
        tensorFlowInferenceInterface.run(output_names, logStats);
        Trace.endSection();

        Trace.beginSection("fetch");
        tensorFlowInferenceInterface.fetch(output_names[0], outputLocations);
        tensorFlowInferenceInterface.fetch(output_names[1], outputScores);
        tensorFlowInferenceInterface.fetch(output_names[2], outputClasses);
        tensorFlowInferenceInterface.fetch(output_names[3], outputNumDetections);
        Trace.endSection();


        for (int i = 0; i < outputScores.length; ++i) {

            if(outputScores[i] > 0.6) {
                Point p1 = new Point(outputLocations[4 * i + 1] * inputSize,
                        outputLocations[4 * i] * inputSize);
                Point p2 = new Point(outputLocations[4 * i + 3] * inputSize,
                        outputLocations[4 * i + 2] * inputSize);

                final Rect detection = new Rect(p1, p2);

                Imgproc.rectangle(img, p1,p2,
                        new Scalar(0,0,255), 2);
                Imgproc.putText(img, labels.get((int) outputClasses[i]), p1, Core.FONT_HERSHEY_DUPLEX,
                        1, new Scalar(0, 255, 0), 1);
                Log.d(TAG, "            " + outputClasses[i]);
            }
        }

        Imgproc.resize(img, img_rgb, img_rgb.size());
        img.release();
    }

OpenCVManager.class

该类主要是接受回调对象,以及设置了相应的生命周期,随主活动而调用。

最终效果

其他说明

物体识别需要有良好的Tensorflow 和OpenCV的知识,本Demo是一个很好的入门程序,虽然Android开发工程师不太需要人工智能方面的知识,但是Android一定会是人工智能的重要设备载体,因此了解一些Tensorflow的知识还是特别有必要的。

欢迎关注我的博客