Unverified commit 83b3c386, authored by Ziyi Wu, committed by GitHub

[Enhance] Move ScanNet point alignment from data pre-processing to pipeline (#439)

parent 9c9a86e4
Showing 456 additions and 105 deletions
......@@ -18,6 +18,7 @@ train_pipeline = [
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
......@@ -49,6 +50,7 @@ test_pipeline = [
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
......@@ -82,6 +84,7 @@ eval_pipeline = [
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
......
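For orientation, a condensed sketch of where the new step sits in the ScanNet detection pipeline config (loading parameters mirror the existing ScanNet config and may differ slightly; the remaining transforms are omitted for brevity):

train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='DEPTH',
        shift_height=True,
        load_dim=6,
        use_dim=[0, 1, 2]),
    dict(
        type='LoadAnnotations3D',
        with_bbox_3d=True,
        with_label_3d=True,
        with_mask_3d=True,
        with_seg_3d=True),
    # new: align the whole scene before label mapping, sampling and augmentation
    dict(type='GlobalAlignment', rotation_axis=2),
    # ... PointSegClassMapping, IndoorPointSample, RandomFlip3D, formatting, Collect3D
]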
......@@ -34,9 +34,10 @@ def export_one_scan(scan_name,
scan_name + '_vh_clean_2.0.010000.segs.json')
# includes axisAlignment info for the train set scans.
meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt')
mesh_vertices, semantic_labels, instance_labels, instance_bboxes, \
instance2semantic = export(mesh_file, agg_file, seg_file,
meta_file, label_map_file, None, test_mode)
mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \
aligned_bboxes, instance2semantic, axis_align_matrix = export(
mesh_file, agg_file, seg_file, meta_file, label_map_file, None,
test_mode)
if not test_mode:
mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
......@@ -47,9 +48,12 @@ def export_one_scan(scan_name,
num_instances = len(np.unique(instance_labels))
print(f'Num of instances: {num_instances}')
bbox_mask = np.in1d(instance_bboxes[:, -1], OBJ_CLASS_IDS)
instance_bboxes = instance_bboxes[bbox_mask, :]
print(f'Num of care instances: {instance_bboxes.shape[0]}')
bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS)
unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
aligned_bboxes = aligned_bboxes[bbox_mask, :]
assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
print(f'Num of care instances: {unaligned_bboxes.shape[0]}')
if max_num_point is not None:
max_num_point = int(max_num_point)
......@@ -65,7 +69,11 @@ def export_one_scan(scan_name,
if not test_mode:
np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
np.save(f'{output_filename_prefix}_bbox.npy', instance_bboxes)
np.save(f'{output_filename_prefix}_unaligned_bbox.npy',
unaligned_bboxes)
np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes)
np.save(f'{output_filename_prefix}_axis_align_matrix.npy',
axis_align_matrix)
def batch_export(max_num_point,
......
......@@ -52,6 +52,24 @@ def read_segmentation(filename):
return seg_to_verts, num_verts
def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id,
instance_ids):
num_instances = len(np.unique(list(object_id_to_segs.keys())))
instance_bboxes = np.zeros((num_instances, 7))
for obj_id in object_id_to_segs:
label_id = object_id_to_label_id[obj_id]
obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
if len(obj_pc) == 0:
continue
xyz_min = np.min(obj_pc, axis=0)
xyz_max = np.max(obj_pc, axis=0)
bbox = np.concatenate([(xyz_min + xyz_max) / 2.0, xyz_max - xyz_min,
np.array([label_id])])
# NOTE: this assumes obj_id is in 1, 2, 3, ..., NUM_INSTANCES
instance_bboxes[obj_id - 1, :] = bbox
return instance_bboxes
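A minimal sketch of how the newly factored-out extract_bbox helper behaves on toy inputs (the arrays and ids below are made up for illustration; in the script it receives the real ScanNet mesh vertices and segment mappings):

import numpy as np

# toy scene: four points, two per object instance
# extract_bbox is the helper defined in the hunk above
mesh_vertices = np.array([[0., 0., 0.], [1., 1., 1.],
                          [2., 2., 0.], [4., 3., 1.]])
instance_ids = np.array([1, 1, 2, 2])
object_id_to_segs = {1: [0], 2: [1]}      # only the keys matter here
object_id_to_label_id = {1: 3, 2: 4}      # semantic label id per object

bboxes = extract_bbox(mesh_vertices, object_id_to_segs,
                      object_id_to_label_id, instance_ids)
# each row is (cx, cy, cz, dx, dy, dz, label_id), e.g.
# object 1 -> [0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 3.0]

Calling it twice, once with the raw vertices and once with the aligned ones, is how export() now produces the unaligned and aligned box sets.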
def export(mesh_file,
agg_file,
seg_file,
......@@ -69,7 +87,7 @@ def export(mesh_file,
label_map_file (str): Path of the label_map_file.
output_file (str): Path of the output folder.
Default: None.
test_mode (bool): Whether is generating training data without labels.
test_mode (bool): Whether is generating test data without labels.
Default: False.
It returns a tuple, which contains the following things:
......@@ -86,8 +104,7 @@ def export(mesh_file,
# Load scene axis alignment matrix
lines = open(meta_file).readlines()
# TODO: test set data doesn't have align_matrix!
# TODO: save align_matrix and move align step to pipeline in the future
# test set data doesn't have align_matrix
axis_align_matrix = np.eye(4)
for line in lines:
if 'axisAlignment' in line:
......@@ -97,10 +114,13 @@ def export(mesh_file,
]
break
axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))
# perform global alignment of mesh vertices
pts = np.ones((mesh_vertices.shape[0], 4))
pts[:, 0:3] = mesh_vertices[:, 0:3]
pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4
mesh_vertices[:, 0:3] = pts[:, 0:3]
aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]],
axis=1)
# Load semantic and instance labels
if not test_mode:
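The alignment itself is a plain homogeneous-coordinate transform; a standalone numpy sketch of the same step (the matrix below is an arbitrary example, not taken from a real scan):

import numpy as np

# arbitrary example: 90-degree rotation about z plus a small translation
axis_align_matrix = np.array([[0., -1., 0., 1.5],
                              [1.,  0., 0., 0.5],
                              [0.,  0., 1., 0.0],
                              [0.,  0., 0., 1.0]])
mesh_vertices = np.random.rand(100, 6)  # xyz + rgb

pts = np.ones((mesh_vertices.shape[0], 4))
pts[:, :3] = mesh_vertices[:, :3]
aligned_xyz = (pts @ axis_align_matrix.T)[:, :3]
aligned_mesh_vertices = np.concatenate(
    [aligned_xyz, mesh_vertices[:, 3:]], axis=1)

Since the matrix is now saved to disk as well, the same transform can be re-applied later by the GlobalAlignment pipeline step instead of being baked into the points during pre-processing.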
......@@ -115,34 +135,21 @@ def export(mesh_file,
label_ids[verts] = label_id
instance_ids = np.zeros(
shape=(num_verts), dtype=np.uint32) # 0: unannotated
num_instances = len(np.unique(list(object_id_to_segs.keys())))
for object_id, segs in object_id_to_segs.items():
for seg in segs:
verts = seg_to_verts[seg]
instance_ids[verts] = object_id
if object_id not in object_id_to_label_id:
object_id_to_label_id[object_id] = label_ids[verts][0]
instance_bboxes = np.zeros((num_instances, 7))
for obj_id in object_id_to_segs:
label_id = object_id_to_label_id[obj_id]
obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
if len(obj_pc) == 0:
continue
xmin = np.min(obj_pc[:, 0])
ymin = np.min(obj_pc[:, 1])
zmin = np.min(obj_pc[:, 2])
xmax = np.max(obj_pc[:, 0])
ymax = np.max(obj_pc[:, 1])
zmax = np.max(obj_pc[:, 2])
bbox = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2,
(zmin + zmax) / 2, xmax - xmin, ymax - ymin,
zmax - zmin, label_id])
# NOTE: this assumes obj_id is in 1,2,3,.,,,.NUM_INSTANCES
instance_bboxes[obj_id - 1, :] = bbox
unaligned_bboxes = extract_bbox(mesh_vertices, object_id_to_segs,
object_id_to_label_id, instance_ids)
aligned_bboxes = extract_bbox(aligned_mesh_vertices, object_id_to_segs,
object_id_to_label_id, instance_ids)
else:
label_ids = None
instance_ids = None
instance_bboxes = None
unaligned_bboxes = None
aligned_bboxes = None
object_id_to_label_id = None
if output_file is not None:
......@@ -150,10 +157,12 @@ def export(mesh_file,
if not test_mode:
np.save(output_file + '_sem_label.npy', label_ids)
np.save(output_file + '_ins_label.npy', instance_ids)
np.save(output_file + '_bbox.npy', instance_bboxes)
np.save(output_file + '_unaligned_bbox.npy', unaligned_bboxes)
np.save(output_file + '_aligned_bbox.npy', aligned_bboxes)
np.save(output_file + '_axis_align_matrix.npy', axis_align_matrix)
return mesh_vertices, label_ids, instance_ids, \
instance_bboxes, object_id_to_label_id
return mesh_vertices, label_ids, instance_ids, unaligned_bboxes, \
aligned_bboxes, object_id_to_label_id, axis_align_matrix
def main():
......
......@@ -129,12 +129,15 @@ class BaseInstance3DBoxes(object):
pass
@abstractmethod
def rotate(self, angles, axis=0):
"""Calculate whether the points are in any of the boxes.
def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.
Args:
angles (float): Rotation angles.
axis (int): The axis to rotate the boxes.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
"""
pass
......@@ -144,7 +147,7 @@ class BaseInstance3DBoxes(object):
pass
def translate(self, trans_vector):
"""Calculate whether the points are in any of the boxes.
"""Translate boxes with the given translation vector.
Args:
trans_vector (torch.Tensor): Translation vector of size 1x3.
......
......@@ -169,10 +169,12 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
return bev_boxes
def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.
Args:
angle (float, torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
......@@ -183,10 +185,20 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'
if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, 0, -rot_sin], [0, 1, 0],
rot_mat_T = self.tensor.new_tensor([[rot_cos, 0, -rot_sin],
[0, 1, 0],
[rot_sin, 0, rot_cos]])
else:
rot_mat_T = angle
rot_sin = rot_mat_T[2, 0]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
self.tensor[:, 6] += angle
......
......@@ -116,10 +116,12 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
return bev_boxes
def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.
Args:
angle (float, torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
......@@ -130,11 +132,21 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'
if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0], [0, 0,
1]]).T
[rot_sin, rot_cos, 0],
[0, 0, 1]]).T
else:
rot_mat_T = angle.T
rot_sin = rot_mat_T[0, 1]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)
self.tensor[:, 0:3] = self.tensor[:, 0:3] @ rot_mat_T
if self.with_yaw:
self.tensor[:, 6] -= angle
......
......@@ -114,10 +114,12 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
return bev_boxes
def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.
Args:
angle (float | torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
......@@ -128,10 +130,20 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'
if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0], [0, 0, 1]])
[rot_sin, rot_cos, 0],
[0, 0, 1]])
else:
rot_mat_T = angle
rot_sin = rot_mat_T[1, 0]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
self.tensor[:, 6] += angle
......
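With this change, rotate() on CameraInstance3DBoxes, DepthInstance3DBoxes and LiDARInstance3DBoxes accepts either a scalar yaw angle or a 3x3 rotation matrix. A small usage sketch with arbitrary box values, assuming the same matrix convention as the updated tests (the matrix passed is used as the transposed rotation applied on the right):

import numpy as np
import torch
from mmdet3d.core import LiDARInstance3DBoxes

boxes = LiDARInstance3DBoxes(
    torch.tensor([[0., 0., -1., 1.6, 3.9, 1.5, 0.3]]))

# rotate by a scalar yaw angle
boxes.rotate(0.1)

# or by the equivalent 3x3 matrix about the z axis
angle = 0.1
rot_mat = np.array([[np.cos(angle), -np.sin(angle), 0.],
                    [np.sin(angle), np.cos(angle), 0.],
                    [0., 0., 1.]])
boxes.rotate(rot_mat)  # yaw is recovered via arctan2 and added to the boxes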
......@@ -147,7 +147,7 @@ class BasePoints(object):
if not isinstance(rotation, torch.Tensor):
rotation = self.tensor.new_tensor(rotation)
assert rotation.shape == torch.Size([3, 3]) or \
rotation.numel() == 1
rotation.numel() == 1, f'invalid rotation shape {rotation.shape}'
if axis is None:
axis = self.rotation_axis
......
......@@ -7,7 +7,8 @@ from .kitti_mono_dataset import KittiMonoDataset
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
from .nuscenes_mono_dataset import NuScenesMonoDataset
from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
......@@ -26,10 +27,10 @@ __all__ = [
'DATASETS', 'build_dataset', 'CocoDataset', 'NuScenesDataset',
'NuScenesMonoDataset', 'LyftDataset', 'ObjectSample', 'RandomFlip3D',
'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter',
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile',
'NormalizePointsColor', 'IndoorPointSample', 'LoadAnnotations3D',
'SUNRGBDDataset', 'ScanNetDataset', 'ScanNetSegDataset', 'S3DISSegDataset',
'SemanticKITTIDataset', 'Custom3DDataset', 'Custom3DSegDataset',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'VoxelBasedPointSampler', 'get_loading_pipeline'
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline'
]
......@@ -6,11 +6,11 @@ from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadPointsFromMultiSweeps, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
IndoorPatchPointSample, IndoorPointSample,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D,
VoxelBasedPointSampler)
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
......@@ -19,6 +19,6 @@ __all__ = [
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D'
]
......@@ -293,6 +293,93 @@ class ObjectNoise(object):
return repr_str
@PIPELINES.register_module()
class GlobalAlignment(object):
"""Apply global alignment to 3D scene points by rotation and translation.
Args:
rotation_axis (int): Rotation axis for points and bboxes rotation.
Note:
We do not record the applied rotation and translation as in \
GlobalRotScaleTrans, because we usually do not need to reverse \
the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth \
bounding boxes for evaluation.
"""
def __init__(self, rotation_axis):
self.rotation_axis = rotation_axis
def _trans_points(self, input_dict, trans_factor):
"""Private function to translate points.
Args:
input_dict (dict): Result dict from loading pipeline.
trans_factor (np.ndarray): Translation vector to be applied.
Returns:
dict: Results after translation, 'points' is updated in the dict.
"""
input_dict['points'].translate(trans_factor)
def _rot_points(self, input_dict, rot_mat):
"""Private function to rotate bounding boxes and points.
Args:
input_dict (dict): Result dict from loading pipeline.
rot_mat (np.ndarray): Rotation matrix to be applied.
Returns:
dict: Results after rotation, 'points' is updated in the dict.
"""
# points.rotate() expects the transposed rotation matrix, so transpose it here
input_dict['points'].rotate(rot_mat.T)
def _check_rot_mat(self, rot_mat):
"""Check if rotation matrix is valid for self.rotation_axis.
Args:
rot_mat (np.ndarray): Rotation matrix to be applied.
"""
is_valid = np.allclose(np.linalg.det(rot_mat), 1.0)
valid_array = np.zeros(3)
valid_array[self.rotation_axis] = 1.0
is_valid &= (rot_mat[self.rotation_axis, :] == valid_array).all()
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}'
def __call__(self, input_dict):
"""Call function to shuffle points.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after global alignment, 'points' and keys in \
input_dict['bbox3d_fields'] are updated in the result dict.
"""
assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \
'axis_align_matrix is not provided in GlobalAlignment'
axis_align_matrix = input_dict['ann_info']['axis_align_matrix']
assert axis_align_matrix.shape == (4, 4), \
f'invalid shape {axis_align_matrix.shape} for axis_align_matrix'
rot_mat = axis_align_matrix[:3, :3]
trans_vec = axis_align_matrix[:3, -1]
self._check_rot_mat(rot_mat)
self._rot_points(input_dict, rot_mat)
self._trans_points(input_dict, trans_vec)
return input_dict
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(rotation_axis={self.rotation_axis})'
return repr_str
@PIPELINES.register_module()
class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene.
......
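A minimal sketch of exercising the new transform on its own, assuming an info entry that already carries the 4x4 axis_align_matrix (an identity placeholder is used here instead of a real matrix):

import numpy as np
from mmdet3d.core.points import DepthPoints
from mmdet3d.datasets import GlobalAlignment

global_alignment = GlobalAlignment(rotation_axis=2)

points = DepthPoints(
    np.random.rand(100, 6).astype(np.float32), points_dim=6)
axis_align_matrix = np.eye(4).astype(np.float32)  # placeholder matrix

input_dict = dict(
    points=points,
    ann_info=dict(axis_align_matrix=axis_align_matrix))
input_dict = global_alignment(input_dict)
aligned_points = input_dict['points']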
import numpy as np
import tempfile
import warnings
from os import path as osp
from mmdet3d.core import show_result, show_seg_result
......@@ -79,6 +80,8 @@ class ScanNetDataset(Custom3DDataset):
- gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks.
- pts_semantic_mask_path (str): Path of semantic masks.
- axis_align_matrix (np.ndarray): Transformation matrix for \
global scene alignment.
"""
# Use index to get the annos, thus the evalhook could also use this api
info = self.data_infos[index]
......@@ -102,13 +105,55 @@ class ScanNetDataset(Custom3DDataset):
pts_semantic_mask_path = osp.join(self.data_root,
info['pts_semantic_mask_path'])
axis_align_matrix = self._get_axis_align_matrix(info)
anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d,
pts_instance_mask_path=pts_instance_mask_path,
pts_semantic_mask_path=pts_semantic_mask_path)
pts_semantic_mask_path=pts_semantic_mask_path,
axis_align_matrix=axis_align_matrix)
return anns_results
def prepare_test_data(self, index):
"""Prepare data for testing.
We should take axis_align_matrix from self.data_infos since we need \
to align point clouds.
Args:
index (int): Index for accessing the target data.
Returns:
dict: Testing data dict of the corresponding index.
"""
input_dict = self.get_data_info(index)
# take the axis_align_matrix from data_infos
input_dict['ann_info'] = dict(
axis_align_matrix=self._get_axis_align_matrix(
self.data_infos[index]))
self.pre_pipeline(input_dict)
example = self.pipeline(input_dict)
return example
@staticmethod
def _get_axis_align_matrix(info):
"""Get axis_align_matrix from info. If not exist, return identity mat.
Args:
info (dict): one data info term.
Returns:
np.ndarray: 4x4 transformation matrix.
"""
if 'axis_align_matrix' in info['annos'].keys():
return info['annos']['axis_align_matrix'].astype(np.float32)
else:
warnings.warn(
'axis_align_matrix is not found in ScanNet data info, please '
'use new pre-process scripts to re-generate ScanNet data')
return np.eye(4).astype(np.float32)
def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
......@@ -118,6 +163,7 @@ class ScanNetDataset(Custom3DDataset):
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
......
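A quick sketch of the fallback behaviour of the new helper (the info dicts below are hypothetical stand-ins for entries of scannet_infos.pkl):

import numpy as np
from mmdet3d.datasets import ScanNetDataset

# new-style info entry with the matrix saved by the updated scripts
info = dict(annos=dict(axis_align_matrix=np.eye(4)))
matrix = ScanNetDataset._get_axis_align_matrix(info)  # float32, shape (4, 4)

# old-style info without the matrix: warns and falls back to identity
old_info = dict(annos=dict())
fallback = ScanNetDataset._get_axis_align_matrix(old_info)
assert np.allclose(fallback, np.eye(4))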
......@@ -27,6 +27,11 @@ def test_getitem():
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)),
dict(type='IndoorPointSample', num_points=5),
dict(
type='RandomFlip3D',
......@@ -63,11 +68,12 @@ def test_getitem():
assert file_name == './tests/data/scannet/points/scene0000_00.bin'
assert np.allclose(pcd_rotation, expected_rotation, 1e-3)
assert sample_idx == 'scene0000_00'
expected_points = torch.tensor([[-2.7231, -2.2068, 2.3543, 2.3895],
[-0.4065, -3.4857, 2.1330, 2.1682],
[-1.4578, 1.3510, -0.0441, -0.0089],
[2.2428, -1.1323, -0.0288, 0.0064],
[0.7052, -2.9752, 1.5560, 1.5912]])
expected_points = torch.tensor(
[[1.8339e+00, 2.1093e+00, 2.2900e+00, 2.3895e+00],
[3.6079e+00, 1.4592e-01, 2.0687e+00, 2.1682e+00],
[4.1886e+00, 5.0614e+00, -1.0841e-01, -8.8736e-03],
[6.8790e+00, 1.5086e+00, -9.3154e-02, 6.3816e-03],
[4.8253e+00, 2.6668e-01, 1.4917e+00, 1.5912e+00]])
expected_gt_bboxes_3d = torch.tensor(
[[-1.1835, -3.6317, 1.5704, 1.7577, 0.3761, 0.5724, 0.0000],
[-3.1832, 3.2269, 1.1911, 0.6727, 0.2251, 0.6715, 0.0000],
......@@ -78,7 +84,7 @@ def test_getitem():
6, 6, 4, 9, 11, 11, 10, 0, 15, 17, 17, 17, 3, 12, 4, 4, 14, 1, 0, 0, 0,
0, 0, 0, 5, 5, 5
])
expected_pts_semantic_mask = np.array([3, 1, 2, 2, 15])
expected_pts_semantic_mask = np.array([0, 18, 18, 18, 18])
expected_pts_instance_mask = np.array([44, 22, 10, 10, 57])
original_classes = scannet_dataset.CLASSES
......@@ -165,6 +171,32 @@ def test_evaluate():
assert abs(ret_dict['counter_AP_0.25'] - 1.0) < 0.01
assert abs(ret_dict['curtain_AP_0.25'] - 1.0) < 0.01
# test evaluate with pipeline
class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet',
'sink', 'bathtub', 'garbagebin')
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
]
ret_dict = scannet_dataset.evaluate(
results, metric, pipeline=eval_pipeline)
assert abs(ret_dict['table_AP_0.25'] - 0.3333) < 0.01
assert abs(ret_dict['window_AP_0.25'] - 1.0) < 0.01
assert abs(ret_dict['counter_AP_0.25'] - 1.0) < 0.01
assert abs(ret_dict['curtain_AP_0.25'] - 1.0) < 0.01
def test_show():
import mmcv
......@@ -226,6 +258,7 @@ def test_show():
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
......
......@@ -5,9 +5,10 @@ import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
ObjectNoise, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D,
VoxelBasedPointSampler)
def test_remove_points_in_boxes():
......@@ -221,6 +222,39 @@ def test_points_range_filter():
assert repr_str == expected_repr_str
def test_global_alignment():
np.random.seed(0)
global_alignment = GlobalAlignment(rotation_axis=2)
points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
np.float32).reshape(-1, 6)
annos = mmcv.load('tests/data/scannet/scannet_infos.pkl')
info = annos[0]
axis_align_matrix = info['annos']['axis_align_matrix']
depth_points = DepthPoints(points.copy(), points_dim=6)
input_dict = dict(
points=depth_points.clone(),
ann_info=dict(axis_align_matrix=axis_align_matrix))
input_dict = global_alignment(input_dict)
trans_depth_points = input_dict['points']
# construct expected transformed points by affine transformation
pts = np.ones((points.shape[0], 4))
pts[:, :3] = points[:, :3]
trans_pts = np.dot(pts, axis_align_matrix.T)
expected_points = np.concatenate([trans_pts[:, :3], points[:, 3:]], axis=1)
assert np.allclose(
trans_depth_points.tensor.numpy(), expected_points, atol=1e-6)
repr_str = repr(global_alignment)
expected_repr_str = 'GlobalAlignment(rotation_axis=2)'
assert repr_str == expected_repr_str
def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......
......@@ -27,6 +27,11 @@ def test_scannet_pipeline():
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)),
dict(type='IndoorPointSample', num_points=5),
dict(
type='RandomFlip3D',
......@@ -66,6 +71,8 @@ def test_scannet_pipeline():
results['ann_info']['gt_bboxes_3d'] = DepthInstance3DBoxes(
scannet_gt_bboxes_3d, box_dim=6, with_yaw=False)
results['ann_info']['gt_labels_3d'] = scannet_gt_labels_3d
results['ann_info']['axis_align_matrix'] = \
info['annos']['axis_align_matrix']
results['img_fields'] = []
results['bbox3d_fields'] = []
......@@ -79,11 +86,12 @@ def test_scannet_pipeline():
gt_labels_3d = results['gt_labels_3d']._data
pts_semantic_mask = results['pts_semantic_mask']._data
pts_instance_mask = results['pts_instance_mask']._data
expected_points = torch.tensor([[-2.7231, -2.2068, 2.3543, 2.3895],
[-0.4065, -3.4857, 2.1330, 2.1682],
[-1.4578, 1.3510, -0.0441, -0.0089],
[2.2428, -1.1323, -0.0288, 0.0064],
[0.7052, -2.9752, 1.5560, 1.5912]])
expected_points = torch.tensor(
[[1.8339e+00, 2.1093e+00, 2.2900e+00, 2.3895e+00],
[3.6079e+00, 1.4592e-01, 2.0687e+00, 2.1682e+00],
[4.1886e+00, 5.0614e+00, -1.0841e-01, -8.8736e-03],
[6.8790e+00, 1.5086e+00, -9.3154e-02, 6.3816e-03],
[4.8253e+00, 2.6668e-01, 1.4917e+00, 1.5912e+00]])
expected_gt_bboxes_3d = torch.tensor(
[[-1.1835, -3.6317, 1.8565, 1.7577, 0.3761, 0.5724, 0.0000],
[-3.1832, 3.2269, 1.5268, 0.6727, 0.2251, 0.6715, 0.0000],
......@@ -94,7 +102,7 @@ def test_scannet_pipeline():
6, 6, 4, 9, 11, 11, 10, 0, 15, 17, 17, 17, 3, 12, 4, 4, 14, 1, 0, 0, 0,
0, 0, 0, 5, 5, 5
])
expected_pts_semantic_mask = np.array([3, 1, 2, 2, 15])
expected_pts_semantic_mask = np.array([0, 18, 18, 18, 18])
expected_pts_instance_mask = np.array([44, 22, 10, 10, 57])
assert torch.allclose(points, expected_points, 1e-2)
assert torch.allclose(gt_bboxes_3d.tensor[:5, :], expected_gt_bboxes_3d,
......
......@@ -11,6 +11,7 @@ from mmdet3d.core.bbox.structures.utils import (get_box_type, limit_period,
points_cam2img,
rotation_3d_in_axis,
xywhr2xyxyr)
from mmdet3d.core.points import CameraPoints, DepthPoints, LiDARPoints
def test_bbox3d_mapping_back():
......@@ -225,6 +226,7 @@ def test_lidar_boxes3d():
assert torch.allclose(points, expected_points)
# test box rotation
# with input torch.Tensor points and angle
expected_tensor = torch.tensor(
[[1.4225, -2.7344, -1.7501, 1.7500, 3.3900, 1.6500, 1.7976],
[8.5435, -3.6491, -1.6357, 1.5400, 4.0100, 1.5700, 1.6576],
......@@ -244,6 +246,16 @@ def test_lidar_boxes3d():
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input torch.Tensor points and rotation matrix
points, rot_mat_T = boxes.rotate(-0.13603681398218053, points) # back
rot_mat = np.array([[0.99076125, -0.13561762, 0.],
[0.13561762, 0.99076125, 0.], [0., 0., 1.]])
points, rot_mat_T = boxes.rotate(rot_mat, points)
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input np.ndarray points and angle
points_np = np.array([[-1.0280, 0.9888,
-1.4658], [-4.3695, 2.1310, -1.3857],
[-6.5263, 1.5595,
......@@ -262,6 +274,15 @@ def test_lidar_boxes3d():
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
# with input LiDARPoints and rotation matrix
points_np, rot_mat_T_np = boxes.rotate(-0.13603681398218053, points_np)
lidar_points = LiDARPoints(points_np)
lidar_points, rot_mat_T_np = boxes.rotate(rot_mat, lidar_points)
points_np = lidar_points.tensor.numpy()
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
# test box scaling
expected_tensor = torch.tensor([[
1.0443488, -2.9183323, -1.7599131, 1.7597977, 3.4089797, 1.6592377,
......@@ -701,6 +722,7 @@ def test_camera_boxes3d():
assert torch.allclose(points, expected_points)
# test box rotation
# with input torch.Tensor points and angle
expected_tensor = Box3DMode.convert(
torch.tensor(
[[1.4225, -2.7344, -1.7501, 1.7500, 3.3900, 1.6500, 1.7976],
......@@ -722,6 +744,17 @@ def test_camera_boxes3d():
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input torch.Tensor points and rotation matrix
points, rot_mat_T = boxes.rotate(
torch.tensor(-0.13603681398218053), points) # back
rot_mat = np.array([[0.99076125, 0., -0.13561762], [0., 1., 0.],
[0.13561762, 0., 0.99076125]])
points, rot_mat_T = boxes.rotate(rot_mat, points)
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input np.ndarray points and angle
points_np = np.array([[0.6762, 1.2559, -1.4658, 2.5359],
[0.8784, 4.7814, -1.3857, 0.7167],
[-0.2517, 6.7053, -0.9697, 0.5599],
......@@ -741,6 +774,15 @@ def test_camera_boxes3d():
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
# with input CameraPoints and rotation matrix
points_np, rot_mat_T_np = boxes.rotate(
torch.tensor(-0.13603681398218053), points_np)
camera_points = CameraPoints(points_np, points_dim=4)
camera_points, rot_mat_T_np = boxes.rotate(rot_mat, camera_points)
points_np = camera_points.tensor.numpy()
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
# test box scaling
expected_tensor = Box3DMode.convert(
torch.tensor([[
......@@ -1007,7 +1049,7 @@ def test_depth_boxes3d():
# test box concatenation
expected_tensor = torch.tensor(
[[1.4856, 2.5299, -0.5570, 0.9385, 2.1404, 0.8954, 3.0601],
[2.3262, 3.3065, --0.44255, 0.8234, 0.5325, 1.0099, 2.9971],
[2.3262, 3.3065, 0.44255, 0.8234, 0.5325, 1.0099, 2.9971],
[2.4593, 2.5870, -0.4321, 0.8597, 0.6193, 1.0204, 3.0693],
[1.4856, 2.5299, -0.5570, 0.9385, 2.1404, 0.8954, 3.0601]])
boxes = DepthInstance3DBoxes.cat([boxes_1, boxes_2])
......@@ -1049,14 +1091,16 @@ def test_depth_boxes3d():
[0.5358, -4.5870, -1.4741, 0.0556]])
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert torch.allclose(points, expected_points)
# test box rotation
# with input torch.Tensor points and angle
boxes_rot = boxes.clone()
expected_tensor = torch.tensor(
[[-1.5434, -2.4951, -0.5570, 0.9385, 2.1404, 0.8954, -0.0585],
[-2.4016, -3.2521, 0.4426, 0.8234, 0.5325, 1.0099, -0.1215],
[-2.5181, -2.5298, -0.4321, 0.8597, 0.6193, 1.0204, -0.0493],
[-1.5434, -2.4951, -0.5570, 0.9385, 2.1404, 0.8954, -0.0585]])
points, rot_mar_T = boxes_rot.rotate(-0.022998953275003075, points)
points, rot_mat_T = boxes_rot.rotate(-0.022998953275003075, points)
expected_points = torch.tensor([[-0.7049, -1.2400, -1.4658, 2.5359],
[-0.9881, -4.7599, -1.3857, 0.7167],
[0.0974, -6.7093, -0.9697, 0.5599],
......@@ -1067,14 +1111,24 @@ def test_depth_boxes3d():
[0.0000, 0.0000, 1.0000]])
assert torch.allclose(boxes_rot.tensor, expected_tensor, 1e-3)
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mar_T, expected_rot_mat_T, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input torch.Tensor points and rotation matrix
points, rot_mat_T = boxes.rotate(0.022998953275003075, points) # back
rot_mat = np.array([[0.99973554, 0.02299693, 0.],
[-0.02299693, 0.99973554, 0.], [0., 0., 1.]])
points, rot_mat_T = boxes.rotate(rot_mat, points)
assert torch.allclose(boxes_rot.tensor, expected_tensor, 1e-3)
assert torch.allclose(points, expected_points, 1e-3)
assert torch.allclose(rot_mat_T, expected_rot_mat_T, 1e-3)
# with input np.ndarray points and angle
points_np = np.array([[0.6762, 1.2559, -1.4658, 2.5359],
[0.8784, 4.7814, -1.3857, 0.7167],
[-0.2517, 6.7053, -0.9697, 0.5599],
[0.5520, 0.6533, -0.5265, 1.0032],
[-0.5358, 4.5870, -1.4741, 0.0556]])
points_np, rot_mar_T_np = boxes.rotate(-0.022998953275003075, points_np)
points_np, rot_mat_T_np = boxes.rotate(-0.022998953275003075, points_np)
expected_points_np = np.array([[0.7049, 1.2400, -1.4658, 2.5359],
[0.9881, 4.7599, -1.3857, 0.7167],
[-0.0974, 6.7093, -0.9697, 0.5599],
......@@ -1090,7 +1144,17 @@ def test_depth_boxes3d():
[-1.5434, -2.4951, -0.5570, 0.9385, 2.1404, 0.8954, -0.0585]])
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mar_T_np, expected_rot_mat_T_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
# with input DepthPoints and rotation matrix
points_np, rot_mat_T_np = boxes.rotate(0.022998953275003075, points_np)
depth_points = DepthPoints(points_np, points_dim=4)
depth_points, rot_mat_T_np = boxes.rotate(rot_mat, depth_points)
points_np = depth_points.tensor.numpy()
assert torch.allclose(boxes.tensor, expected_tensor, 1e-3)
assert np.allclose(points_np, expected_points_np, 1e-3)
assert np.allclose(rot_mat_T_np, expected_rot_mat_T_np, 1e-3)
th_boxes = torch.tensor(
[[0.61211395, 0.8129094, 0.10563634, 1.497534, 0.16927195, 0.27956772],
[1.430009, 0.49797538, 0.9382923, 0.07694054, 0.9312509, 1.8919173]],
......
......@@ -42,12 +42,24 @@ class ScanNetData(object):
def __len__(self):
return len(self.sample_id_list)
def get_box_label(self, idx):
def get_aligned_box_label(self, idx):
box_file = osp.join(self.root_dir, 'scannet_instance_data',
f'{idx}_bbox.npy')
f'{idx}_aligned_bbox.npy')
mmcv.check_file_exist(box_file)
return np.load(box_file)
def get_unaligned_box_label(self, idx):
box_file = osp.join(self.root_dir, 'scannet_instance_data',
f'{idx}_unaligned_bbox.npy')
mmcv.check_file_exist(box_file)
return np.load(box_file)
def get_axis_align_matrix(self, idx):
matrix_file = osp.join(self.root_dir, 'scannet_instance_data',
f'{idx}_axis_align_matrix.npy')
mmcv.check_file_exist(matrix_file)
return np.load(matrix_file)
def get_infos(self, num_workers=4, has_label=True, sample_id_list=None):
"""Get data infos.
......@@ -106,25 +118,35 @@ class ScanNetData(object):
if has_label:
annotations = {}
boxes_with_classes = self.get_box_label(
sample_idx) # k, 6 + class
annotations['gt_num'] = boxes_with_classes.shape[0]
# box is of shape [k, 6 + class]
aligned_box_label = self.get_aligned_box_label(sample_idx)
unaligned_box_label = self.get_unaligned_box_label(sample_idx)
annotations['gt_num'] = aligned_box_label.shape[0]
if annotations['gt_num'] != 0:
minmax_boxes3d = boxes_with_classes[:, :-1] # k, 6
classes = boxes_with_classes[:, -1] # k, 1
aligned_box = aligned_box_label[:, :-1] # k, 6
unaligned_box = unaligned_box_label[:, :-1]
classes = aligned_box_label[:, -1] # k
annotations['name'] = np.array([
self.label2cat[self.cat_ids2class[classes[i]]]
for i in range(annotations['gt_num'])
])
annotations['location'] = minmax_boxes3d[:, :3]
annotations['dimensions'] = minmax_boxes3d[:, 3:6]
annotations['gt_boxes_upright_depth'] = minmax_boxes3d
# default names are given to aligned bbox for compatibility
# we also save unaligned bbox info with marked names
annotations['location'] = aligned_box[:, :3]
annotations['dimensions'] = aligned_box[:, 3:6]
annotations['gt_boxes_upright_depth'] = aligned_box
annotations['unaligned_location'] = unaligned_box[:, :3]
annotations['unaligned_dimensions'] = unaligned_box[:, 3:6]
annotations[
'unaligned_gt_boxes_upright_depth'] = unaligned_box
annotations['index'] = np.arange(
annotations['gt_num'], dtype=np.int32)
annotations['class'] = np.array([
self.cat_ids2class[classes[i]]
for i in range(annotations['gt_num'])
])
axis_align_matrix = self.get_axis_align_matrix(sample_idx)
annotations['axis_align_matrix'] = axis_align_matrix # 4x4
info['annos'] = annotations
return info
......@@ -197,9 +219,6 @@ class ScanNetSegData(object):
mask = np.load(mask)
else:
mask = np.fromfile(mask, dtype=np.long)
# first filter out unannotated points (labeled as 0)
mask = mask[mask != 0]
# then convert to [0, 20) labels
label = self.cat_id2class[mask]
return label
......
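With the re-generated infos, each annotated scene entry now exposes both box variants plus the matrix consumed by GlobalAlignment; a hedged sketch of inspecting one entry (the pkl path is an assumption, and the box keys are only present when the scene has at least one annotated instance):

import mmcv

infos = mmcv.load('data/scannet/scannet_infos_train.pkl')
annos = infos[0]['annos']

aligned_boxes = annos['gt_boxes_upright_depth']              # (k, 6), aligned frame
unaligned_boxes = annos['unaligned_gt_boxes_upright_depth']  # (k, 6), original frame
axis_align_matrix = annos['axis_align_matrix']               # (4, 4)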