《Keras 3 :使用迁移学习进行关键点检测》
作者:Sayak Paul,由 Muhammad Anas Raza
转换为 Keras 3 创建日期:2021/05/02
最后修改时间:2023/07/19
描述:使用数据增强和迁移学习训练关键点检测器。
在 Colab 中查看
GitHub 源
关键点检测包括定位关键对象部分。例如,关键部分 的脸包括鼻尖、眉毛、眼角等。这些部件有助于 以功能丰富的方式表示底层对象。关键点检测具有 包括姿势估计、人脸检测等的应用程序。
在此示例中,我们将使用 StanfordExtra 数据集 StanfordExtra 构建一个关键点检测器 使用迁移学习。此示例需要 TensorFlow 2.4 或更高版本, 以及 Imgaug 图书馆, 可以使用以下命令进行安装:
!pip install -q -U imgaug
数据采集
StanfordExtra 数据集包含 12,000 张狗图像以及关键点和 分割图。它是从 Stanford dogs 数据集开发的。 可以使用以下命令下载它:
!wget -q http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
注释在 StanfordExtra 数据集中以单个 JSON 文件的形式提供,并且需要 填写此表单以访问它。这 作者明确指示用户不要共享 JSON 文件,此示例尊重此愿望: 您应该自己获取 JSON 文件。
JSON 文件应在本地以 .stanfordextra_v12.zip
下载文件后,我们可以提取档案。
!tar xf images.tar
!unzip -qq ~/stanfordextra_v12.zip
进口
from keras import layers
import keras
from imgaug.augmentables.kps import KeypointsOnImage
from imgaug.augmentables.kps import Keypoint
import imgaug.augmenters as iaa
from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import json
import os
定义超参数
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 5
NUM_KEYPOINTS = 24 * 2 # 24 pairs each having x and y coordinates
加载数据
作者还提供了一个元数据文件,该文件指定了有关 关键点,如颜色信息、动物姿势名称等。我们将此文件加载到 DataFrame 中,以提取用于可视化目的的信息。pandas
IMG_DIR = "Images"
JSON = "StanfordExtra_V12/StanfordExtra_v12.json"
KEYPOINT_DEF = (
"https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"
)
# Load the ground-truth annotations.
with open(JSON) as infile:
json_data = json.load(infile)
# Set up a dictionary, mapping all the ground-truth information
# with respect to the path of the image.
json_dict = {
i["img_path"]: i for i in json_data}
的单个条目如下所示:json_dict
'n02085782-Japanese_spaniel/n02085782_2886.jpg':
{'img_bbox': [205, 20, 116, 201],
'img_height': 272,
'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg',
'img_width': 350,
'is_multiple_dogs': False,
'joints': [[108.66666666666667, 252.0, 1],
[147.66666666666666, 229.0, 1],
[163.5, 208.5, 1],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[54.0, 244.0, 1],
[77.33333333333333, 225.33333333333334, 1],
[79.0, 196.5, 1],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[150.66666666666666, 86.66666666666667, 1],
[88.66666666666667, 73.0, 1],
[116.0, 106.33333333333333, 1],
[109.0, 123.33333333333333, 1],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]],
'seg': ...}
在此示例中,我们感兴趣的键是:
img_path
joints
里面总共有 24 个条目。每个条目有 3 个值:joints
- x 坐标
- y 坐标
- 关键点的可见性标志(1 表示可见性,0 表示不可见)
正如我们所看到的,包含多个条目,这些条目表示这些 关键点没有标记。在此示例中,我们将考虑 non-visible 和 未标记的关键点,以便进行小批量学习。joints
[0, 0, 0]
# Load the metdata definition file and preview it.
keypoint_def = pd.read_csv(KEYPOINT_DEF)
keypoint_def.head()
# Extract the colours and labels.
colours = keypoint_def["Hex colour"].values.tolist()
colours = ["#" + colour for colour in colours]
labels = keypoint_def["Name"].values.tolist()
# Utility for reading an image and for getting its annotations.
def get_dog(name