Notice
Recent Posts
Recent Comments
Link
아는 만큼 보인다
Tensorflow SavedModel 포맷 모델을 Tensorflow Lite (TFLite)로 변환할 때 input_signature 오류 본문
머신러닝&딥러닝
Tensorflow SavedModel 포맷 모델을 Tensorflow Lite (TFLite)로 변환할 때 input_signature 오류
계토 2023. 10. 13. 17:10tensorflow lite 소개, 변환, 추론 관련 내용은 여기에 정리되어 있다.
SavedModel 형식과 관련된 공식 문서는 이곳이다.
그런데 이게 사용해보니, 처음에 SavedModel 형식으로 모델을 저장할 때 input_signature라는 것을 지정해줘야 했다.
즉 batch size와 input length 등 input meta 정보를 같이 저장해주어야 했다. 그냥 모델을 저장하니 에러가 나서, 구글링 끝에 얻은 코드는 다음과 같다.
아래는 TFLite로 변환하기 전, 즉 model을 pb 모델로 제대로 저장하기 위한 코드이다.
def save_SavedModel(model, batch, input_len, save_path):
class MyModule(tf.Module):
def __init__(self, model):
self.model = model
@tf.function(input_signature=[tf.TensorSpec([batch, input_len, 1], tf.float32)])
def predict(self, input_data):
return self.model(input_data)
module = MyModule(model)
tf.saved_model.save(module, save_path, signatures={"serving_default": module.predict})
우선 model을 module object로 만들어주고, predict라는 function을 정의해준다. 이 때 @tf.function을 통해 input signature를 지정해준다. 이후 이 모듈 자체를 저장해주고, 이때 signature를 지정하며 model.predict 사용. 이렇게 저장하면 pb에서 tflite로 에러없이 잘 변환된다.