CIFAR10 TFLM快速搭建指南 目录
一、概述 二、环境准备 2.1 虚拟环境 2.2 使用Google Colab 三、核心步骤 3.1 CIFAR10数据集 3.2 模型创建 3.3 模型训练 3.4 模型转换 3.5 推理验证 3.6 benchmark性能 3.7 完整实现 3.8 简化版本(Colab) 四、TFLite Micro部署 五、快速验证 六、应用示例 七、结论 八、参考
Spoiler (Highlight to read): For more details, please see the attachment.
一、概述
CIFAR-10:由多伦多大学Alex Krizhevsky等人整理发布的公开数据集,是计算机视觉领域最经典、最常用的入门级基准数据集之一,包含10个类别的6万张32x32彩色图像(5万张训练,1万张测试),类别包括飞机、汽车、鸟、猫等。
tflm_cifar10:演示了如何在恩智浦的微控制器上使用TensorFlow Lite Micro框架,实时运行CIFAR-10图像分类模型。即将一个预先训练好的、针对CIFAR-10数据集的卷积神经网络模型部署到MCU上,让其具备了识别10类常见物体(飞机、汽车、鸟、猫等)的能力。
模型:一个轻量级CNN模型,包含3个卷积层、ReLU激活层、池化层和一个全连接层。
输入: 32x32像素的彩色图像。
输出:图像属于CIFAR-10中10个类别的概率。
本文档:提供基于CIFAR10数据集的完整搭建流程,涵盖从数据集准备、模型训练与转换到推理验证的快速实现方案,可作为示例tflm_cifar10(以推理为主)的前置补充;本文不涉及端侧的部署与优化。
...
"""
Complete quick workflow for CIFAR10: training, testing, deployment and inference.
"""
import os
# Set before the tensorflow import below so the setting is already in place
# when TensorFlow is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TensorFlow log output
import tensorflow as tf
import numpy as np
import time
import matplotlib.pyplot as plt
# Report the library versions in use at startup.
print(f"TensorFlow版本: {tf.__version__}")
print(f"NumPy版本: {np.__version__}")
class CIFAR10QuickPipeline:
    """Quick CIFAR-10 pipeline: load data, train, convert to TFLite, verify.

    State is filled in as the steps run: ``model`` is set by
    ``create_simple_model`` and ``tflite_model`` by ``convert_to_tflite``
    (or read lazily from ``cifar10_model.tflite`` when it is still ``None``).
    """

    def __init__(self):
        """Create a pipeline with no Keras model and no converted model yet."""
        self.model = None
        self.tflite_model = None

    def _ensure_tflite_model(self):
        """Return the TFLite model bytes, reading them from disk if needed.

        Returns:
            The serialized model held in ``self.tflite_model``.

        Raises:
            OSError: If no model is in memory and ``cifar10_model.tflite``
                cannot be read.
        """
        if self.tflite_model is None:
            with open('cifar10_model.tflite', 'rb') as f:
                self.tflite_model = f.read()
        return self.tflite_model

    def load_data(self, sample_size=1000):
        """Load a reduced CIFAR-10 subset for fast experimentation.

        Pixel values are scaled to [0, 1] and labels are one-hot encoded
        over 10 classes.

        Args:
            sample_size: Number of leading training samples to keep.

        Returns:
            ``((x_train, y_train), (x_test, y_test))`` where the test part
            holds the first 200 test samples.
        """
        print("\n1. 加载CIFAR10数据集...")
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

        # Scale uint8 pixels into the [0, 1] float range.
        x_train = x_train.astype('float32') / 255.0
        x_test = x_test.astype('float32') / 255.0

        # Keep only a small slice of the data so training stays quick.
        x_train_small = x_train[:sample_size]
        y_train_small = y_train[:sample_size]
        x_test_small = x_test[:200]
        y_test_small = y_test[:200]

        # One-hot encode the labels to match categorical cross-entropy.
        y_train_onehot = tf.keras.utils.to_categorical(y_train_small, 10)
        y_test_onehot = tf.keras.utils.to_categorical(y_test_small, 10)

        print(f"训练数据: {x_train_small.shape}")
        print(f"测试数据: {x_test_small.shape}")
        return (x_train_small, y_train_onehot), (x_test_small, y_test_onehot)

    def create_simple_model(self):
        """Build and compile a small CNN and store it in ``self.model``.

        The network has two Conv2D + MaxPooling2D stages followed by a
        32-unit dense layer, dropout, and a 10-way softmax output.

        Returns:
            The compiled Keras model.
        """
        print("\n2. 创建简单CNN模型...")
        model = tf.keras.Sequential([
            # Input: 32x32 RGB images.
            tf.keras.layers.Input(shape=(32, 32, 3)),
            # Convolution stage 1.
            tf.keras.layers.Conv2D(8, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            # Convolution stage 2.
            tf.keras.layers.Conv2D(16, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            # Classifier head.
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        model.summary()
        self.model = model
        return model

    def train_model(self, x_train, y_train, x_test, y_test, epochs=10):
        """Train ``self.model`` with early stopping and report test accuracy.

        Args:
            x_train: Training images.
            y_train: One-hot training labels.
            x_test: Images used both for validation during training and for
                the final evaluation.
            y_test: One-hot labels matching ``x_test``.
            epochs: Maximum number of training epochs.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        print("\n3. 训练模型...")
        # Stop when validation loss has not improved for 3 epochs and roll
        # back to the best weights seen.
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=3,
                restore_best_weights=True,
            )
        ]
        history = self.model.fit(
            x_train, y_train,
            epochs=epochs,
            batch_size=32,
            validation_data=(x_test, y_test),
            callbacks=callbacks,
            verbose=1,
        )
        # Only the accuracy is reported; the loss value is not used.
        _, test_acc = self.model.evaluate(x_test, y_test, verbose=0)
        print(f"\n测试准确率: {test_acc:.4f}")
        return history

    def convert_to_tflite(self):
        """Convert ``self.model`` to TFLite and write the output artifacts.

        Writes ``cifar10_model.tflite`` and, through ``save_as_c_array``,
        ``cifar10_model_array.h``; also stores the bytes in
        ``self.tflite_model``.

        Returns:
            The serialized TFLite model as bytes.
        """
        print("\n4. 转换为TFLite格式...")
        converter = tf.lite.TFLiteConverter.from_keras_model(self.model)
        # NOTE(review): DEFAULT optimization is combined with float32 as the
        # only supported type - presumably to keep a float model; confirm
        # against the TFLite converter documentation.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float32]
        tflite_model = converter.convert()

        with open('cifar10_model.tflite', 'wb') as f:
            f.write(tflite_model)

        # Also emit the model as a C byte array for embedded deployment.
        self.save_as_c_array(tflite_model)
        self.tflite_model = tflite_model

        model_size = len(tflite_model) / 1024
        print(f"模型大小: {model_size:.1f} KB")
        return tflite_model

    def save_as_c_array(self, tflite_model):
        """Write the model bytes to ``cifar10_model_array.h`` as a C array.

        The header defines ``cifar10_model_tflite`` (12 bytes per source
        line) and ``cifar10_model_tflite_len``.

        Args:
            tflite_model: Serialized model as a bytes-like object.
        """
        bytes_per_line = 12
        parts = [
            '// 自动生成的CIFAR10模型数组\n',
            # The original emitted a bare "#include" with no header name,
            # which does not compile. NOTE(review): <stdint.h> is presumably
            # the header that was intended - confirm.
            '#include <stdint.h>\n\n',
            'const unsigned char cifar10_model_tflite[] = {\n',
        ]
        # Collect lines and join once instead of repeated string +=.
        parts.extend(
            ' ' + ', '.join(f'0x{b:02x}' for b in tflite_model[i:i + bytes_per_line]) + ',\n'
            for i in range(0, len(tflite_model), bytes_per_line)
        )
        parts.append('};\n\n')
        parts.append(
            f'const unsigned int cifar10_model_tflite_len = {len(tflite_model)};\n'
        )
        # Explicit UTF-8: the header comment contains non-ASCII text and the
        # platform default encoding may not be able to represent it.
        with open('cifar10_model_array.h', 'w', encoding='utf-8') as f:
            f.write(''.join(parts))
        print("C数组已保存: cifar10_model_array.h")

    def test_tflite_inference(self, x_test, y_test, num_tests=10):
        """Run the TFLite model on the first samples and report accuracy.

        Args:
            x_test: Test images; sample ``i`` is fed as ``x_test[i:i+1]``.
            y_test: One-hot labels matching ``x_test``.
            num_tests: Maximum number of samples to evaluate.

        Returns:
            ``(accuracy, avg_time)``: the fraction of evaluated samples that
            were classified correctly, and the mean per-sample time in
            milliseconds. Both are 0.0 when no sample was evaluated.
        """
        print(f"\n5. 测试TFLite推理 ({num_tests}个样本)...")
        interpreter = tf.lite.Interpreter(model_content=self._ensure_tflite_model())
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        class_names = ['飞机', '汽车', '鸟', '猫', '鹿',
                       '狗', '青蛙', '马', '船', '卡车']

        # Evaluate only as many samples as actually exist; this count is
        # also the denominator of the accuracy below.
        sample_count = max(0, min(num_tests, len(x_test)))
        correct = 0
        times = []
        for i in range(sample_count):
            input_data = x_test[i:i + 1]  # keep the batch dimension

            # The timed region covers setting the input and invoking.
            start_time = time.perf_counter()
            interpreter.set_tensor(input_details[0]['index'], input_data)
            interpreter.invoke()
            inference_time = time.perf_counter() - start_time
            times.append(inference_time)

            output = interpreter.get_tensor(output_details[0]['index'])
            predicted_class = np.argmax(output[0])
            actual_class = np.argmax(y_test[i])
            is_correct = predicted_class == actual_class
            if is_correct:
                correct += 1
            print(f"样本 {i+1}: 预测={class_names[predicted_class]:<5} "
                  f"实际={class_names[actual_class]:<5} "
                  f"时间={inference_time*1000:.1f}ms "
                  f"{'✓' if is_correct else '✗'}")

        # Divide by the number of samples actually run, not the requested
        # count, and avoid dividing by zero when nothing was run.
        accuracy = correct / sample_count if sample_count else 0.0
        avg_time = np.mean(times) * 1000 if times else 0.0
        fps = 1000 / avg_time if avg_time > 0 else float('inf')
        print(f"\n推理统计:")
        print(f" 准确率: {accuracy:.1%} ({correct}/{sample_count})")
        print(f" 平均推理时间: {avg_time:.1f}ms")
        print(f" 推理速度: {fps:.0f} FPS")
        return accuracy, avg_time

    def benchmark_performance(self, x_test, num_runs=100):
        """Measure the average TFLite invoke time on a single sample.

        Args:
            x_test: Test images; only ``x_test[0:1]`` is used as input.
            num_runs: Number of timed ``invoke`` calls (at least 1).

        Returns:
            Average time per invoke in milliseconds.

        Raises:
            ValueError: If ``num_runs`` is smaller than 1.
        """
        if num_runs < 1:
            raise ValueError("num_runs must be at least 1")
        print("\n6. 性能基准测试...")
        # Load from disk when needed, consistent with test_tflite_inference.
        interpreter = tf.lite.Interpreter(model_content=self._ensure_tflite_model())
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()

        # Warm-up runs, excluded from the measurement.
        test_input = x_test[0:1]
        for _ in range(10):
            interpreter.set_tensor(input_details[0]['index'], test_input)
            interpreter.invoke()

        # Timed runs: invoke only, the input set during warm-up is reused.
        start_time = time.perf_counter()
        for _ in range(num_runs):
            interpreter.invoke()
        total_time = time.perf_counter() - start_time

        avg_time = total_time / num_runs * 1000
        fps = 1000 / avg_time if avg_time > 0 else float('inf')
        print(f"基准测试结果:")
        print(f" 总推理次数: {num_runs}")
        print(f" 总时间: {total_time*1000:.1f}ms")
        print(f" 平均推理时间: {avg_time:.1f}ms")
        print(f" 推理速度: {fps:.0f} FPS")
        return avg_time

    def save_model_summary(self):
        """Write the Keras model summary and basic info to ``model_summary.txt``."""
        summary = []

        def collect(line, **_kwargs):
            # Tolerate extra keyword arguments a Keras version may pass.
            summary.append(line)

        self.model.summary(print_fn=collect)
        # Explicit UTF-8: the file contains non-ASCII text and the platform
        # default encoding may not be able to represent it.
        with open('model_summary.txt', 'w', encoding='utf-8') as f:
            f.write('\n'.join(summary))
            f.write("\n\n模型信息:")
            f.write(f"\n参数数量: {self.model.count_params():,}")
            f.write(f"\n保存时间: {time.ctime()}")
        print("模型摘要已保存: model_summary.txt")
def main():
    """Run every pipeline stage in order and list the generated artifacts."""
    banner = "=" * 60
    print(banner)
    print("CIFAR10 快速训练、测试、部署管道")
    print(banner)

    runner = CIFAR10QuickPipeline()

    # Stage 1: fetch the reduced dataset.
    train_split, test_split = runner.load_data(sample_size=2000)
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Stage 2: build the network.
    runner.create_simple_model()

    # Stage 3: fit it (the returned history is not used afterwards).
    runner.train_model(
        train_images, train_labels, test_images, test_labels, epochs=15
    )

    # Stage 4: dump the architecture summary to disk.
    runner.save_model_summary()

    # Stage 5: export to TFLite (also writes the C array header).
    runner.convert_to_tflite()

    # Stage 6: sanity-check inference with the converted model.
    runner.test_tflite_inference(test_images, test_labels, num_tests=20)

    # Stage 7: time the converted model.
    runner.benchmark_performance(test_images)

    print("\n" + banner)
    print("流程完成!生成的文件:")
    artifact_lines = (
        " 1. cifar10_model.tflite - TFLite模型",
        " 2. cifar10_model_array.h - C数组格式",
        " 3. model_summary.txt - 模型摘要",
    )
    for artifact_line in artifact_lines:
        print(artifact_line)
    print(banner)


if __name__ == "__main__":
    main()
...
七、结论
本文旨在以常见图像分类场景(CIFAR10)为例,让读者快速了解从数据准备、模型创建、训练到推理和验证的完整流程,可作为示例tflm_cifar10(以端侧推理为主)的前置补充;本文不涉及端侧部署与优化。
記事全体を表示