스터디/Android+Kotlin

Android Kotlin ONNX 연동 (MobileFace 모델)

Dalmangyi 2021. 10. 12.

ONNX(오픈 뉴럴 네트워크 익스체인지)

기계 학습이나 딥러닝 모델을 공통의 연산자 집합으로 바꿔서 여러 프레임워크와 컴파일러에서 사용할 수 있도록 해주는 표준을 말합니다.
이번 게시글에서는 MobileFace 모델을 안드로이드 기반에서 구동될 수 있도록 변환하기 위해 사용되었습니다.

 

MobileFace

한정된 모바일 환경에서 적은 리소스를 사용하여 얼굴을 찾는 모델입니다.

오픈 모델이 항상 그러하듯, 성능의 한계가 있습니다.
1. 얼굴 각도가 땅을 보고 있을때 인식이 잘 안됩니다. (face 7)
2. 어두운 조명에 있는 얼굴은 인식이 잘 안됩니다. (face 12, face14)
3. 흑인을 잘 인식하지 못 하고, 깜빡입니다. (face 9)
4. 같은 사람임에도 불구하고 연속적인 얼굴로 인지 하지 못합니다 (face 8 -> face 13)

그래도 이러한 인식은 추가적인 개발로 방지할 수 있습니다.
모바일에서 중요한건 얼마나 빠른 연산을 하냐는것이 중요합니다. 
MobileFace_Identification_V3에서는 정말 빠른 연산 처리가 가능했습니다.

MobileFace 사용예시

 

 

 

ONNX로 변환한 MobileFace

raw 폴더에 추가해 줍니다.

mobileface.onnx
4.25MB

 

 

라이브러리

libs 폴더에 오닉스 런타임(onnxruntime-release.aar) 파일을 추가합니다.
오닉스 런타임은 오닉스 형식으로 변환된 모델을 읽을 수 있게 도와줍니다.
이 게시글에서는 ONNX로 변환된 MobileFace 모델을 읽기 위해서 사용됩니다.

onnxruntime-release.aar
2.80MB

 

 

build.gradle

//직접 추가한 dependency를 추가해 줍니다.
implementation(name: "onnxruntime-release", ext: "aar")

 

 

NNAPI (Android Neural Networks API)

안드로이드 기기에서 하드웨어 가속 추론 작업을 하기 위해 만들어진 C API 입니다.
Android 11 이상 실행이 가능한 기기에서 사용할 수 있습니다.

 

 

 

얼굴 이미지(Bitmap)에서 특징 점(Feature Float Array) 추출 (Kotlin)

object FROnnxMobileNet {

    //안드로이드 가속화
    private var enableNNAPI: Boolean = false

    //오닉스 런타임 (ONNX)
    private var ortEnv: OrtEnvironment? = null
    private var ortSession:OrtSession? = null


    const val IMAGE_MEAN: Float = .0f
    const val IMAGE_STD: Float = 255f
    const val DIM_BATCH_SIZE = 1
    const val DIM_PIXEL_SIZE = 3
    const val IMAGE_SIZE_X = 128
    const val IMAGE_SIZE_Y = 128

    init {
        ortEnv = OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL)
        ortSession = CreateOrtSession()
    }

    private fun CreateOrtSession(): OrtSession? {
        val so = OrtSession.SessionOptions()
        so.use {
            // Set to use 2 intraOp threads for CPU EP
            so.setIntraOpNumThreads(2)

            if (enableNNAPI)
                so.addNnapi()

            return ortEnv?.createSession(readModel(), so)
        }
    }

    private fun readModel(): ByteArray {
        val res = CovidApplication.instance.resources
        return res.openRawResource(R.raw.mobileface).readBytes()
    }



    fun preprocess2(bitmap: Bitmap): FloatBuffer {
        val imgData = FloatBuffer.allocate(
            DIM_BATCH_SIZE
                    * IMAGE_SIZE_X
                    * IMAGE_SIZE_Y
                    * DIM_PIXEL_SIZE)
        imgData.rewind()

        val bmpData = IntArray(IMAGE_SIZE_X * IMAGE_SIZE_Y)
        bitmap.getPixels(bmpData, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)

        var idx: Int = 0
        for (i in 0..IMAGE_SIZE_X - 1) {
            for (j in 0..IMAGE_SIZE_Y - 1) {
                val pixelValue = bmpData[idx++]
                imgData.put(((pixelValue shr 16 and 0xFF) - IMAGE_MEAN) / IMAGE_STD)
                imgData.put(((pixelValue shr 8 and 0xFF) - IMAGE_MEAN) / IMAGE_STD)
                imgData.put(((pixelValue and 0xFF) - IMAGE_MEAN) / IMAGE_STD)
            }
        }

        imgData.rewind()
        return imgData
    }

    fun analyze(bitmap:Bitmap) : FloatArray? {
        val resizeBitmap = Bitmap.createScaledBitmap(bitmap, 128, 128, false)
        val imgData = preprocess2(resizeBitmap) //128x128 => 49152

        val inputName = ortSession?.inputNames?.iterator()?.next()

        val shape = longArrayOf(1, 3, 128, 128)
//            val shape = longArrayOf(3, 224, 224)
        val ortEnv = OrtEnvironment.getEnvironment()
        ortEnv.use {
            // Create input tensor
            val input_tensor = OnnxTensor.createTensor(ortEnv, imgData, shape)
            val startTime = SystemClock.uptimeMillis()
            input_tensor.use {
                // Run the inference and get the output tensor
                val output = ortSession?.run(Collections.singletonMap(inputName, input_tensor))
//                    val output = ortSession?.run(Collections.singletonMap("input", input_tensor))
                output.use {

                    //reshape(-1)
                    val output0 = (output?.get(0)?.value) as Array<Array<Array<FloatArray>>>
                    val output1 = output0[0].flatten()
                    output1[0].forEach { }
                    val output2 = FloatArray(output1.size)
                    output1.forEachIndexed { index, floats ->
                        floats.forEach {
                            output2[index] = it
                        }
                    }

                    output.close()

                    return output2
                }
            }
        }
    }

}

 

 

 

 

 

 

 

댓글