Unverified commit 3870001a authored by Ziyi Wu, committed by GitHub

[Feature] Support entire PAConv and PAConvCUDA models (#783)

* add PAConv decode head

* add config files

* add paconv's correlation loss

* support reg loss in Segmentor class

* minor fix

* add augmentation to configs

* fix ed7 in cfg

* fix bug in corr loss

* enable syncbn in paconv

* rename to loss_regularization

* rename loss_reg to loss_regularize

* use SyncBN

* change weight kernels to kernel weights

* rename corr_loss to reg_loss

* minor fix

* configs fix IndoorPatchPointSample

* fix grouped points minus center error

* update transform_3d & add configs

* merge master

* fix enlarge_size bug

* refine config

* remove cfg files

* minor fix

* add comments on PAConv's ScoreNet

* refine comments

* update compatibility doc

* remove useless lines in transforms_3d

* rename with_loss_regularization to with_regularization_loss

* revert palette change

* remove xavier init from PAConv's ScoreNet
parent a8f47523
Showing 659 additions and 81 deletions
_base_ = './paconv_ssg.py'
model = dict(
backbone=dict(
sa_cfg=dict(
type='PAConvCUDASAModule',
scorenet_cfg=dict(mlp_channels=[8, 16, 16]))))
# model settings
model = dict(
type='EncoderDecoder3D',
backbone=dict(
type='PointNet2SASSG',
in_channels=9, # [xyz, rgb, normalized_xyz]
num_points=(1024, 256, 64, 16),
radius=(None, None, None, None), # use kNN instead of ball query
num_samples=(32, 32, 32, 32),
sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256,
512)),
fp_channels=(),
norm_cfg=dict(type='BN2d', momentum=0.1),
sa_cfg=dict(
type='PAConvSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=False,
paconv_num_kernels=[16, 16, 16],
paconv_kernel_input='w_neighbor',
scorenet_input='w_neighbor_dist',
scorenet_cfg=dict(
mlp_channels=[16, 16, 16],
score_norm='softmax',
temp_factor=1.0,
last_bn=False))),
decode_head=dict(
type='PAConvHead',
# PAConv model's decoder takes skip connections from the backbone.
# Different from PointNet++, it also concats input features in the last
# level of the decoder, leading to `128 + 6` as the channel number
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
(128 + 6, 128, 128, 128)),
channels=128,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None, # should be modified with dataset
loss_weight=1.0)),
# correlation loss to regularize PAConv's kernel weights
loss_regularization=dict(
type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide'))
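A usage sketch for the config above (hedged: assumes the mmdet3d 0.16-era builder API exercised in the tests below, a CUDA-enabled build for the point ops, and the illustrative path `configs/paconv/paconv_ssg.py`):
import torch
from mmcv import Config

from mmdet3d.models import build_segmentor

cfg = Config.fromfile('configs/paconv/paconv_ssg.py')  # illustrative path
# train_cfg/test_cfg already live inside cfg.model in this config
model = build_segmentor(cfg.model).cuda()

# S3DIS-style inputs: 9 channels = [xyz, rgb, normalized_xyz], 13 classes
points = [torch.rand(1024, 9).cuda() for _ in range(2)]
img_metas = [dict(), dict()]
gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)]
losses = model.forward_train(points, img_metas, gt_masks)
# expect keys such as 'decode.loss_sem_seg' and 'regularize.loss_regularize'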
......@@ -4,6 +4,12 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet
## MMDetection3D 0.16.0
### Returned values of `QueryAndGroup` operation
We modified the returned `grouped_xyz` value of the `QueryAndGroup` operation to support the PAConv segmentor. Originally, `grouped_xyz` was centered by subtracting the grouping centers, so it represented the relative positions of the grouped points. Now we no longer perform this subtraction, and the returned `grouped_xyz` stands for the absolute coordinates of these points.
Note that the other returned variables of `QueryAndGroup`, such as `new_features`, `unique_cnt` and `grouped_idx`, are not affected.
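For downstream code that relied on the centered coordinates, the relative offsets can be recovered with one extra subtraction. A minimal sketch, assuming the tensor shapes used by `QueryAndGroup`:
# grouped_xyz: (B, 3, npoint, sample_num), now absolute coordinates
# center_xyz: (B, npoint, 3), the grouping centers
grouped_xyz_diff = grouped_xyz - center_xyz.transpose(1, 2).unsqueeze(-1)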
### NuScenes coco-style data pre-processing
We removed the rotation and dimension hack in monocular 3D detection on nuScenes. Specifically, we transform the rotation and dimension of boxes defined by the nuScenes devkit to the coordinate system of our `CameraInstance3DBoxes` in the pre-processing, and transform them back in the post-processing. In this way, we can remove the corresponding [hack](https://github.com/open-mmlab/mmdetection3d/pull/744/files#diff-5bee5062bd84e6fa25a2fdd71353f6f283dfdc4a66a0316c3b1ca26078c978b6L165) used in the visualization tools. The modification also guarantees the correctness of all operations based on our `CameraInstance3DBoxes` (such as NMS and flip augmentation) when training monocular 3D detectors.
......
......@@ -43,7 +43,7 @@ def single_gpu_test(model,
models_3d = (Base3DDetector, Base3DSegmentor,
SingleStageMono3DDetector)
if isinstance(model.module, models_3d):
model.module.show_results(data, result, out_dir)
model.module.show_results(data, result, out_dir=out_dir)
# Visualize the results of MMDetection model
# 'show_result' is the MMDetection visualization API
else:
......
......@@ -928,6 +928,7 @@ class IndoorPatchPointSample(object):
Defaults to None.
ignore_index (int, optional): Label index that won't be used for the
segmentation task. This is set in PointSegClassMapping as neg_cls.
If not None, will be used as a patch selection criterion.
Defaults to None.
use_normalized_coord (bool, optional): Whether to use normalized xyz as
additional features. Defaults to False.
......@@ -935,10 +936,12 @@ class IndoorPatchPointSample(object):
is invalid. Defaults to 10.
enlarge_size (float | None, optional): Enlarge the sampled patch to
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
an augmentation. If None, set it as 0.01. Defaults to 0.2.
an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int | None, optional): Minimum number of unique points
the sampled patch should contain. If None, use PointNet++'s method
to judge uniqueness. Defaults to None.
eps (float, optional): A value added to the patch boundary to guarantee
point coverage. Defaults to 1e-2.
Note:
This transform should only be used in the training process of point
......@@ -955,14 +958,16 @@ class IndoorPatchPointSample(object):
use_normalized_coord=False,
num_try=10,
enlarge_size=0.2,
min_unique_num=None):
min_unique_num=None,
eps=1e-2):
self.num_points = num_points
self.block_size = block_size
self.ignore_index = ignore_index
self.use_normalized_coord = use_normalized_coord
self.num_try = num_try
self.enlarge_size = enlarge_size if enlarge_size is not None else 0.01
self.enlarge_size = enlarge_size if enlarge_size is not None else 0.0
self.min_unique_num = min_unique_num
self.eps = eps
if sample_rate is not None:
warnings.warn(
......@@ -1010,7 +1015,7 @@ class IndoorPatchPointSample(object):
return points
def _patch_points_sampling(self, points, sem_mask, replace=None):
def _patch_points_sampling(self, points, sem_mask):
"""Patch points sampling.
First sample a valid patch.
......@@ -1019,8 +1024,6 @@ class IndoorPatchPointSample(object):
Args:
points (:obj:`BasePoints`): 3D Points.
sem_mask (np.ndarray): semantic segmentation mask for input points.
replace (bool): Whether the sample is with or without replacement.
Defaults to None.
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
......@@ -1040,7 +1043,8 @@ class IndoorPatchPointSample(object):
# random sample a point as patch center
cur_center = coords[np.random.choice(coords.shape[0])]
# boundary of a patch
# boundary of a patch, which will be enlarged by
# `self.enlarge_size` as an augmentation
cur_max = cur_center + np.array(
[self.block_size / 2.0, self.block_size / 2.0, 0.0])
cur_min = cur_center - np.array(
......@@ -1057,14 +1061,14 @@ class IndoorPatchPointSample(object):
cur_coords = coords[cur_choice, :]
cur_sem_mask = sem_mask[cur_choice]
# two criterion for patch sampling, adopted from PointNet++
# points within selected patch shoule be scattered separately
point_idxs = np.where(cur_choice)[0]
mask = np.sum(
(cur_coords >= (cur_min - 0.01)) * (cur_coords <=
(cur_max + 0.01)),
(cur_coords >= (cur_min - self.eps)) * (cur_coords <=
(cur_max + self.eps)),
axis=1) == 3
# two criteria for patch sampling, adopted from PointNet++
# 1. selected patch should contain enough unique points
if self.min_unique_num is None:
# use PointNet++'s method as default
# [31, 31, 62] are just some big values used to transform
......@@ -1077,9 +1081,10 @@ class IndoorPatchPointSample(object):
vidx[:, 2])
flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02
else:
# if `min_unique_num` is provided, directly compare with it
flag1 = mask.sum() >= self.min_unique_num
# selected patch should contain enough annotated points
# 2. selected patch should contain enough annotated points
if self.ignore_index is None:
flag2 = True
else:
......@@ -1089,11 +1094,19 @@ class IndoorPatchPointSample(object):
if flag1 and flag2:
break
# random sample idx
if replace is None:
replace = (cur_sem_mask.shape[0] < self.num_points)
# sample idx to `self.num_points`
if point_idxs.size >= self.num_points:
# no duplicate in sub-sampling
choices = np.random.choice(
np.where(cur_choice)[0], self.num_points, replace=replace)
point_idxs, self.num_points, replace=False)
else:
# avoid random choice with replacement here, so that every point
# in the patch is counted at least once
dup = np.random.choice(point_idxs.size,
self.num_points - point_idxs.size)
idx_dup = np.concatenate(
[np.arange(point_idxs.size),
np.array(dup)], 0)
choices = point_idxs[idx_dup]
# construct model input
points = self._input_generation(coords[choices], cur_center, coord_max,
......
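A standalone toy illustration of the duplicate-padding branch above (hedged: plain numpy, outside the transform class). With three candidate indices and `num_points` of 5, every original index is kept once and the remaining slots are filled with duplicates:
import numpy as np

point_idxs = np.array([4, 7, 9])
num_points = 5
# draw the 2 extra positions with replacement from range(3)
dup = np.random.choice(point_idxs.size, num_points - point_idxs.size)
idx_dup = np.concatenate([np.arange(point_idxs.size), np.array(dup)], 0)
choices = point_idxs[idx_dup]  # e.g. array([4, 7, 9, 9, 4]); all 3 appear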
from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head
__all__ = ['PointNet2Head']
__all__ = ['PointNet2Head', 'PAConvHead']
from mmcv.cnn.bricks import ConvModule
from mmdet.models import HEADS
from .pointnet2_head import PointNet2Head
@HEADS.register_module()
class PAConvHead(PointNet2Head):
r"""PAConv decoder head.
Decoder head used in `PAConv <https://arxiv.org/abs/2103.14635>`_.
Refer to the `official code <https://github.com/CVMI-Lab/PAConv>`_.
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules.
Default: dict(type='BN2d').
"""
def __init__(self,
fp_channels=((768, 256, 256), (384, 256, 256),
(320, 256, 128), (128 + 6, 128, 128, 128)),
fp_norm_cfg=dict(type='BN2d'),
**kwargs):
super(PAConvHead, self).__init__(fp_channels, fp_norm_cfg, **kwargs)
# https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53
# PointNet++'s decoder conv has a bias while PAConv's doesn't,
# so we need to rebuild it here
self.pre_seg_conv = ConvModule(
fp_channels[-1][-1],
self.channels,
kernel_size=1,
bias=False,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, feat_dict):
"""Forward pass.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Segmentation map of shape [B, num_classes, N].
"""
sa_xyz, sa_features = self._extract_input(feat_dict)
# PointNet++ doesn't use the first level of `sa_features` as input,
# while PAConv feeds it in through a skip connection
fp_feature = sa_features[-1]
for i in range(self.num_fp):
# consume the points in a bottom-up manner
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
sa_features[-(i + 2)], fp_feature)
output = self.pre_seg_conv(fp_feature)
output = self.cls_seg(output)
return output
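For reference, the default `fp_channels` above can be derived from the SSG backbone config earlier in this commit (a hedged walkthrough; 6 is the channel number of the raw input features fed through the extra skip connection):
# each FP level consumes the previous FP output concatenated with the
# skip-connected SA features (sa_channels end in 64, 128, 256, 512)
#   fp_channels[0][0] = 512 + 256 = 768  (deepest SA features + skip)
#   fp_channels[1][0] = 256 + 128 = 384  (FP0 output + skip)
#   fp_channels[2][0] = 256 + 64  = 320  (FP1 output + skip)
#   fp_channels[3][0] = 128 + 6   = 134  (FP2 output + raw input features)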
......@@ -15,18 +15,22 @@ class PointNet2Head(Base3DDecodeHead):
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules.
Default: dict(type='BN2d').
"""
def __init__(self,
fp_channels=((768, 256, 256), (384, 256, 256),
(320, 256, 128), (128, 128, 128, 128)),
fp_norm_cfg=dict(type='BN2d'),
**kwargs):
super(PointNet2Head, self).__init__(**kwargs)
self.num_fp = len(fp_channels)
self.FP_modules = nn.ModuleList()
for cur_fp_mlps in fp_channels:
self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
self.FP_modules.append(
PointFPModule(mlp_channels=cur_fp_mlps, norm_cfg=fp_norm_cfg))
# https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
self.pre_seg_conv = ConvModule(
......
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss
from .chamfer_distance import ChamferDistance, chamfer_distance
from .paconv_regularization_loss import PAConvRegularizationLoss
__all__ = [
'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss'
'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss',
'PAConvRegularizationLoss'
]
import torch
from torch import nn as nn
from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
def weight_correlation(conv):
"""Calculate correlations between kernel weights in Conv's weight bank as
regularization loss. The cosine similarity is used as metrics.
Args:
conv (nn.Module): A Conv module to be regularized.
Currently only `PAConv` and `PAConvCUDA` are supported.
Returns:
torch.Tensor: Correlations between the kernel weights in the weight bank.
"""
assert isinstance(conv, (PAConv, PAConvCUDA)), \
f'unsupported module type {type(conv)}'
kernels = conv.weight_bank # [C_in, num_kernels * C_out]
in_channels = conv.in_channels
out_channels = conv.out_channels
num_kernels = conv.num_kernels
# [num_kernels, Cin * Cout]
flatten_kernels = kernels.view(in_channels, num_kernels, out_channels).\
permute(1, 0, 2).reshape(num_kernels, -1)
# [num_kernels, num_kernels]
inner_product = torch.matmul(flatten_kernels, flatten_kernels.T)
# [num_kernels, 1]
kernel_norms = torch.sum(flatten_kernels**2, dim=-1, keepdim=True)**0.5
# [num_kernels, num_kernels]
kernel_norms = torch.matmul(kernel_norms, kernel_norms.T)
cosine_sims = inner_product / kernel_norms
# take upper triangular part excluding diagonal since we only compute
# correlation between different kernels once
# the square is to ensure positive loss, refer to:
# https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/tool/train.py#L208
corr = torch.sum(torch.triu(cosine_sims, diagonal=1)**2)
return corr
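In math form, with flattened kernels \(w_i\), the returned value is the sum of squared pairwise cosine similarities:
\[
\mathrm{corr} = \sum_{i < j} \left( \frac{\langle w_i, w_j \rangle}{\lVert w_i \rVert \, \lVert w_j \rVert} \right)^2
\]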
def paconv_regularization_loss(modules, reduction):
"""Computes correlation loss of PAConv weight kernels as regularization.
Args:
modules (List[nn.Module] | :obj:`generator`):
A list or a python generator of torch.nn.Modules.
reduction (str): Method to reduce losses among PAConv modules.
Valid reduction methods are 'none', 'sum' or 'mean'.
Returns:
torch.Tensor: Correlation loss of kernel weights.
"""
corr_loss = []
for module in modules:
if isinstance(module, (PAConv, PAConvCUDA)):
corr_loss.append(weight_correlation(module))
corr_loss = torch.stack(corr_loss)
# perform reduction
corr_loss = weight_reduce_loss(corr_loss, reduction=reduction)
return corr_loss
@LOSSES.register_module()
class PAConvRegularizationLoss(nn.Module):
"""Calculate correlation loss of kernel weights in PAConv's weight bank.
This is used as a regularization term in PAConv model training.
Args:
reduction (str): Method to reduce losses. The reduction is performed
among all PAConv modules instead of prediction tensors.
Valid reduction methods are 'none', 'sum' or 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(PAConvRegularizationLoss, self).__init__()
assert reduction in ['none', 'sum', 'mean']
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self, modules, reduction_override=None, **kwargs):
"""Forward function of loss calculation.
Args:
modules (List[nn.Module] | :obj:`generator`):
A list or a python generator of torch.nn.Modules.
reduction_override (str, optional): Method to reduce losses.
Valid reduction methods are 'none', 'sum' or 'mean'.
Defaults to None.
Returns:
torch.Tensor: Correlation loss of kernel weights.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
return self.loss_weight * paconv_regularization_loss(
modules, reduction=reduction)
......@@ -16,6 +16,12 @@ class Base3DSegmentor(BaseSegmentor):
data_dict and use a 3D seg specific visualization function.
"""
@property
def with_regularization_loss(self):
"""bool: whether the segmentor has regularization loss for weight"""
return hasattr(self, 'loss_regularization') and \
self.loss_regularization is not None
def forward_test(self, points, img_metas, **kwargs):
"""Calls either simple_test or aug_test depending on the length of
outer list of points. If len(points) == 1, call simple_test. Otherwise
......@@ -108,5 +114,12 @@ class Base3DSegmentor(BaseSegmentor):
pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy()
show_seg_result(points, None, pred_sem_mask, out_dir, file_name,
palette, ignore_index)
show_seg_result(
points,
None,
pred_sem_mask,
out_dir,
file_name,
palette,
ignore_index,
show=True)
......@@ -5,7 +5,7 @@ from torch.nn import functional as F
from mmseg.core import add_prefix
from mmseg.models import SEGMENTORS
from ..builder import build_backbone, build_head, build_neck
from ..builder import build_backbone, build_head, build_loss, build_neck
from .base import Base3DSegmentor
......@@ -23,6 +23,7 @@ class EncoderDecoder3D(Base3DSegmentor):
decode_head,
neck=None,
auxiliary_head=None,
loss_regularization=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
......@@ -33,6 +34,7 @@ class EncoderDecoder3D(Base3DSegmentor):
self.neck = build_neck(neck)
self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head)
self._init_loss_regularization(loss_regularization)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -54,6 +56,16 @@ class EncoderDecoder3D(Base3DSegmentor):
else:
self.auxiliary_head = build_head(auxiliary_head)
def _init_loss_regularization(self, loss_regularization):
"""Initialize ``loss_regularization``"""
if loss_regularization is not None:
if isinstance(loss_regularization, list):
self.loss_regularization = nn.ModuleList()
for loss_cfg in loss_regularization:
self.loss_regularization.append(build_loss(loss_cfg))
else:
self.loss_regularization = build_loss(loss_regularization)
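Since `loss_regularization` accepts either a single config dict or a list of them, both forms below are valid; the list form is a hypothetical illustration, while the PAConv config earlier in this commit uses the single-dict form:
# a single regularization loss, as in the PAConv config above
loss_regularization=dict(
    type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0)
# or several losses, wrapped into an nn.ModuleList internally
loss_regularization=[
    dict(type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0)
]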
def extract_feat(self, points):
"""Extract features from points."""
x = self.backbone(points)
......@@ -110,6 +122,21 @@ class EncoderDecoder3D(Base3DSegmentor):
return losses
def _loss_regularization_forward_train(self):
"""Calculate regularization loss for model weight in training."""
losses = dict()
if isinstance(self.loss_regularization, nn.ModuleList):
for idx, regularize_loss in enumerate(self.loss_regularization):
loss_regularize = dict(
loss_regularize=regularize_loss(self.modules()))
losses.update(add_prefix(loss_regularize, f'regularize_{idx}'))
else:
loss_regularize = dict(
loss_regularize=self.loss_regularization(self.modules()))
losses.update(add_prefix(loss_regularize, 'regularize'))
return losses
def forward_dummy(self, points):
"""Dummy forward function."""
seg_logit = self.encode_decode(points, None)
......@@ -145,6 +172,10 @@ class EncoderDecoder3D(Base3DSegmentor):
x, img_metas, pts_semantic_mask_cat)
losses.update(loss_aux)
if self.with_regularization_loss:
loss_regularize = self._loss_regularization_forward_train()
losses.update(loss_regularize)
return losses
@staticmethod
......
......@@ -98,22 +98,23 @@ class QueryAndGroup(nn.Module):
xyz_trans = points_xyz.transpose(1, 2).contiguous()
# (B, 3, npoint, sample_num)
grouped_xyz = grouping_operation(xyz_trans, idx)
grouped_xyz -= center_xyz.transpose(1, 2).unsqueeze(-1)
grouped_xyz_diff = grouped_xyz - \
center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets
if self.normalize_xyz:
grouped_xyz /= self.max_radius
grouped_xyz_diff /= self.max_radius
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
# (B, C + 3, npoint, sample_num)
new_features = torch.cat([grouped_xyz, grouped_features],
new_features = torch.cat([grouped_xyz_diff, grouped_features],
dim=1)
else:
new_features = grouped_features
else:
assert (self.use_xyz
), 'Cannot have both features=None and use_xyz=False!'
new_features = grouped_xyz
new_features = grouped_xyz_diff
ret = [new_features]
if self.return_grouped_xyz:
......
import copy
import torch
from mmcv.cnn import (ConvModule, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
constant_init)
from torch import nn as nn
from torch.nn import functional as F
......@@ -10,7 +10,7 @@ from .utils import assign_kernel_withoutk, assign_score, calc_euclidian_dist
class ScoreNet(nn.Module):
"""ScoreNet that outputs coefficient scores to assemble weight kernels in
r"""ScoreNet that outputs coefficient scores to assemble kernel weights in
the weight bank according to the relative position of point pairs.
Args:
......@@ -26,6 +26,13 @@ class ScoreNet(nn.Module):
bias (bool | str, optional): If specified as `auto`, it will be decided
by the norm_cfg. Bias will be set as True if `norm_cfg` is None,
otherwise False. Defaults to 'auto'.
Note:
The official code applies xavier_init to all Conv layers in ScoreNet,
see `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg
/model/pointnet2/paconv.py#L105>`_. However, in our experiments we did
not observe a noticeable difference with or without this Xavier
initialization, so we omit it in our implementation.
"""
def __init__(self,
......@@ -70,13 +77,6 @@ class ScoreNet(nn.Module):
act_cfg=None,
bias=bias))
def init_weights(self):
"""Initialize weights of shared MLP layers."""
# refer to https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/paconv.py#L105 # noqa
for m in self.mlps.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m)
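If the official behavior is desired, the removed initialization can still be applied externally after building the model; a minimal sketch, assuming `scorenet` is a constructed ScoreNet instance:
from mmcv.cnn import xavier_init
from torch import nn as nn

for m in scorenet.mlps.modules():
    if isinstance(m, nn.Conv2d):
        xavier_init(m)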
def forward(self, xyz_features):
"""Forward.
......@@ -106,14 +106,14 @@ class ScoreNet(nn.Module):
class PAConv(nn.Module):
"""Non-CUDA version of PAConv.
PAConv stores a trainable weight bank containing several weight kernels.
PAConv stores a trainable weight bank containing several kernel weights.
Given input points and features, it computes coefficient scores to assemble
those kernels to form conv kernels, and then runs convolution on the input.
Args:
in_channels (int): Input channels of point features.
out_channels (int): Output channels of point features.
num_kernels (int): Number of weight kernels in the weight bank.
num_kernels (int): Number of kernel weights in the weight bank.
norm_cfg (dict, optional): Type of normalization method.
Defaults to dict(type='BN2d', momentum=0.1).
act_cfg (dict, optional): Type of activation method.
......@@ -124,7 +124,7 @@ class PAConv(nn.Module):
weight_bank_init (str, optional): Init method of weight bank kernels.
Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'.
kernel_input (str, optional): Input features to be multiplied with
weight kernels. Can be 'identity' or 'w_neighbor'.
kernel weights. Can be 'identity' or 'w_neighbor'.
Defaults to 'w_neighbor'.
scorenet_cfg (dict, optional): Config of the ScoreNet module, which
may contain the following keys and values:
......@@ -147,7 +147,7 @@ class PAConv(nn.Module):
weight_bank_init='kaiming',
kernel_input='w_neighbor',
scorenet_cfg=dict(
mlp_channels=[8, 16, 16],
mlp_channels=[16, 16, 16],
score_norm='softmax',
temp_factor=1.0,
last_bn=False)):
......@@ -156,14 +156,15 @@ class PAConv(nn.Module):
# determine weight kernel size according to used features
if kernel_input == 'identity':
# only use grouped_features
self.kernel_mul = 1
kernel_mul = 1
elif kernel_input == 'w_neighbor':
# concat of (grouped_features - center_features, grouped_features)
self.kernel_mul = 2
kernel_mul = 2
else:
raise NotImplementedError(
f'unsupported kernel_input {kernel_input}')
self.kernel_input = kernel_input
in_channels = kernel_mul * in_channels
# determine mlp channels in ScoreNet according to used xyz features
if scorenet_input == 'identity':
......@@ -180,7 +181,7 @@ class PAConv(nn.Module):
f'unsupported scorenet_input {scorenet_input}')
self.scorenet_input = scorenet_input
# construct weight kernels in weight bank
# construct kernel weights in weight bank
# self.weight_bank is of shape [C, num_kernels * out_c]
# where C can be in_c or (2 * in_c)
if weight_bank_init == 'kaiming':
......@@ -191,17 +192,17 @@ class PAConv(nn.Module):
raise NotImplementedError(
f'unsupported weight bank init method {weight_bank_init}')
self.m = num_kernels
self.num_kernels = num_kernels # the parameter `m` in the paper
weight_bank = weight_init(
torch.empty(self.m, in_channels * self.kernel_mul, out_channels))
torch.empty(self.num_kernels, in_channels, out_channels))
weight_bank = weight_bank.permute(1, 0, 2).reshape(
in_channels * self.kernel_mul, self.m * out_channels).contiguous()
in_channels, self.num_kernels * out_channels).contiguous()
self.weight_bank = nn.Parameter(weight_bank, requires_grad=True)
# construct ScoreNet
scorenet_cfg_ = copy.deepcopy(scorenet_cfg)
scorenet_cfg_['mlp_channels'].insert(0, self.scorenet_in_channels)
scorenet_cfg_['mlp_channels'].append(self.m)
scorenet_cfg_['mlp_channels'].append(self.num_kernels)
self.scorenet = ScoreNet(**scorenet_cfg_)
self.bn = build_norm_layer(norm_cfg, out_channels)[1] if \
......@@ -209,13 +210,16 @@ class PAConv(nn.Module):
self.activate = build_activation_layer(act_cfg) if \
act_cfg is not None else None
# set some basic attributes of Conv layers
self.in_channels = in_channels
self.out_channels = out_channels
self.init_weights()
def init_weights(self):
"""Initialize weights of shared MLP layers."""
self.scorenet.init_weights()
"""Initialize weights of shared MLP layers and BN layers."""
if self.bn is not None:
constant_init(self.bn, val=1)
constant_init(self.bn, val=1, bias=0)
def _prepare_scorenet_input(self, points_xyz):
"""Prepare input point pairs features for self.ScoreNet.
......@@ -273,14 +277,15 @@ class PAConv(nn.Module):
# prepare features between each point and its grouping center
xyz_features = self._prepare_scorenet_input(points_xyz)
# scores to assemble weight kernels
# scores to assemble kernel weights
scores = self.scorenet(xyz_features) # [B, npoint, K, m]
# first compute out features over all kernels
# features is [B, C, npoint, K], weight_bank is [C, m * out_c]
new_features = torch.matmul(
features.permute(0, 2, 3, 1), self.weight_bank).\
view(B, npoint, K, self.m, -1) # [B, npoint, K, m, out_c]
features.permute(0, 2, 3, 1),
self.weight_bank).view(B, npoint, K, self.num_kernels,
-1) # [B, npoint, K, m, out_c]
# then aggregate using scores
new_features = assign_score(scores, new_features)
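Equivalently, for a point pair \((x_i, x_j)\), the assembled convolution kernel is a score-weighted combination of the \(m\) weight matrices \(B_k\) in the bank, following the formulation of the PAConv paper:
\[
\mathbf{W}(x_i, x_j) = \sum_{k=1}^{m} S_k(x_i, x_j) \, \mathbf{B}_k
\]
The matmul above computes the features multiplied by every \(B_k\) in one shot, and `assign_score` then aggregates them with the scores.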
......@@ -363,13 +368,13 @@ class PAConvCUDA(PAConv):
# prepare features between each point and its grouping center
xyz_features = self._prepare_scorenet_input(points_xyz)
# scores to assemble weight kernels
# scores to assemble kernel weights
scores = self.scorenet(xyz_features) # [B, npoint, K, m]
# pre-compute features for points and centers separately
# features is [B, in_c, N], weight_bank is [C, m * out_dim]
point_feat, center_feat = assign_kernel_withoutk(
features, self.weight_bank, self.m)
features, self.weight_bank, self.num_kernels)
# aggregate features using custom cuda op
new_features = assign_score_cuda(
......
......@@ -15,10 +15,10 @@ class PAConvSAModuleMSG(BasePointSAModule):
See the `paper <https://arxiv.org/abs/2103.14635>`_ for more details.
Args:
paconv_num_kernels (list[list[int]]): Number of weight kernels in the
paconv_num_kernels (list[list[int]]): Number of kernel weights in the
weight banks of each layer's PAConv.
paconv_kernel_input (str, optional): Input features to be multiplied
with weight kernels. Can be 'identity' or 'w_neighbor'.
with kernel weights. Can be 'identity' or 'w_neighbor'.
Defaults to 'w_neighbor'.
scorenet_input (str, optional): Type of the input to ScoreNet.
Defaults to 'w_neighbor_dist'. Can be the following values:
......@@ -77,7 +77,7 @@ class PAConvSAModuleMSG(BasePointSAModule):
assert len(paconv_num_kernels) == len(mlp_channels)
for i in range(len(mlp_channels)):
assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \
'PAConv number of weight kernels wrong'
'PAConv number of kernel weights wrong'
# in PAConv, bias only exists in ScoreNet
scorenet_cfg['bias'] = bias
......@@ -197,7 +197,7 @@ class PAConvCUDASAModuleMSG(BasePointSAModule):
assert len(paconv_num_kernels) == len(mlp_channels)
for i in range(len(mlp_channels)):
assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \
'PAConv number of weight kernels wrong'
'PAConv number of kernel weights wrong'
# in PAConv, bias only exists in ScoreNet
scorenet_cfg['bias'] = bias
......
import pytest
import torch
from torch import nn as nn
def test_chamfer_disrance():
......@@ -69,3 +70,41 @@ def test_chamfer_disrance():
or torch.equal(indices1, indices1.new_tensor(expected_inds2)))
assert (indices2 == indices2.new_tensor([[0, 0, 0, 0, 0], [0, 3, 6, 0,
0]])).all()
def test_paconv_regularization_loss():
from mmdet3d.models.losses import PAConvRegularizationLoss
from mmdet3d.ops import PAConv, PAConvCUDA
from mmdet.apis import set_random_seed
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.paconvs = nn.ModuleList()
self.paconvs.append(PAConv(8, 16, 8))
self.paconvs.append(PAConv(8, 16, 8, kernel_input='identity'))
self.paconvs.append(PAConvCUDA(8, 16, 8))
self.conv1 = nn.Conv1d(3, 8, 1)
set_random_seed(0, True)
model = ToyModel()
# reduction should be in ['none', 'mean', 'sum']
with pytest.raises(AssertionError):
paconv_corr_loss = PAConvRegularizationLoss(reduction='l2')
paconv_corr_loss = PAConvRegularizationLoss(reduction='mean')
mean_corr_loss = paconv_corr_loss(model.modules())
assert mean_corr_loss >= 0
assert mean_corr_loss.requires_grad
sum_corr_loss = paconv_corr_loss(model.modules(), reduction_override='sum')
assert torch.allclose(sum_corr_loss, mean_corr_loss * 3)
none_corr_loss = paconv_corr_loss(
model.modules(), reduction_override='none')
assert none_corr_loss.shape[0] == 3
assert torch.allclose(none_corr_loss.mean(), mean_corr_loss)
......@@ -37,10 +37,10 @@ def test_paconv_sa_module_msg():
pool_mod='max',
paconv_kernel_input='w_neighbor').cuda()
assert self.mlps[0].layer0.weight_bank.shape[0] == 12 * 2
assert self.mlps[0].layer0.weight_bank.shape[1] == 16 * 4
assert self.mlps[1].layer0.weight_bank.shape[0] == 12 * 2
assert self.mlps[1].layer0.weight_bank.shape[1] == 32 * 8
assert self.mlps[0].layer0.in_channels == 12 * 2
assert self.mlps[0].layer0.out_channels == 16
assert self.mlps[1].layer0.in_channels == 12 * 2
assert self.mlps[1].layer0.out_channels == 32
assert self.mlps[0].layer0.bn.num_features == 16
assert self.mlps[1].layer0.bn.num_features == 32
......@@ -80,10 +80,12 @@ def test_paconv_sa_module_msg():
pool_mod='max',
paconv_kernel_input='identity').cuda()
assert self.mlps[0].layer0.weight_bank.shape[0] == 12 * 1
assert self.mlps[0].layer0.weight_bank.shape[1] == 16 * 4
assert self.mlps[1].layer0.weight_bank.shape[0] == 12 * 1
assert self.mlps[1].layer0.weight_bank.shape[1] == 32 * 8
assert self.mlps[0].layer0.in_channels == 12 * 1
assert self.mlps[0].layer0.out_channels == 16
assert self.mlps[0].layer0.num_kernels == 4
assert self.mlps[1].layer0.in_channels == 12 * 1
assert self.mlps[1].layer0.out_channels == 32
assert self.mlps[1].layer0.num_kernels == 8
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
......@@ -116,8 +118,9 @@ def test_paconv_sa_module():
paconv_kernel_input='w_neighbor')
self = build_sa_module(sa_cfg).cuda()
assert self.mlps[0].layer0.weight_bank.shape[0] == 15 * 2
assert self.mlps[0].layer0.weight_bank.shape[1] == 32 * 8
assert self.mlps[0].layer0.in_channels == 15 * 2
assert self.mlps[0].layer0.out_channels == 32
assert self.mlps[0].layer0.num_kernels == 8
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
......@@ -145,7 +148,7 @@ def test_paconv_sa_module():
pool_mod='max',
paconv_kernel_input='identity')
self = build_sa_module(sa_cfg).cuda()
assert self.mlps[0].layer0.weight_bank.shape[0] == 15 * 1
assert self.mlps[0].layer0.in_channels == 15 * 1
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
......@@ -191,11 +194,13 @@ def test_paconv_cuda_sa_module_msg():
pool_mod='max',
paconv_kernel_input='w_neighbor').cuda()
assert self.mlps[0][0].weight_bank.shape[0] == 12 * 2
assert self.mlps[0][0].weight_bank.shape[1] == 16 * 4
assert self.mlps[1][0].weight_bank.shape[0] == 12 * 2
assert self.mlps[1][0].weight_bank.shape[1] == 32 * 8
assert self.mlps[0][0].in_channels == 12 * 2
assert self.mlps[0][0].out_channels == 16
assert self.mlps[0][0].num_kernels == 4
assert self.mlps[0][0].bn.num_features == 16
assert self.mlps[1][0].in_channels == 12 * 2
assert self.mlps[1][0].out_channels == 32
assert self.mlps[1][0].num_kernels == 8
assert self.mlps[1][0].bn.num_features == 32
assert self.mlps[0][0].scorenet.mlps.layer0.conv.in_channels == 7
......@@ -253,8 +258,9 @@ def test_paconv_cuda_sa_module():
paconv_kernel_input='w_neighbor')
self = build_sa_module(sa_cfg).cuda()
assert self.mlps[0][0].weight_bank.shape[0] == 15 * 2
assert self.mlps[0][0].weight_bank.shape[1] == 32 * 8
assert self.mlps[0][0].in_channels == 15 * 2
assert self.mlps[0][0].out_channels == 32
assert self.mlps[0][0].num_kernels == 8
xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32)
......
......@@ -193,10 +193,13 @@ def test_paconv():
out_channels = 12
npoint = 4
K = 3
num_kernels = 4
points_xyz = torch.randn(B, 3, npoint, K)
features = torch.randn(B, in_channels, npoint, K)
paconv = PAConv(in_channels, out_channels, 4)
paconv = PAConv(in_channels, out_channels, num_kernels)
assert paconv.weight_bank.shape == torch.Size(
[in_channels * 2, out_channels * num_kernels])
with torch.no_grad():
new_features, _ = paconv((features, points_xyz))
......@@ -213,11 +216,14 @@ def test_paconv_cuda():
N = 32
npoint = 4
K = 3
num_kernels = 4
points_xyz = torch.randn(B, 3, npoint, K).float().cuda()
features = torch.randn(B, in_channels, N).float().cuda()
points_idx = torch.randint(0, N, (B, npoint, K)).long().cuda()
paconv = PAConvCUDA(in_channels, out_channels, 4).cuda()
paconv = PAConvCUDA(in_channels, out_channels, num_kernels).cuda()
assert paconv.weight_bank.shape == torch.Size(
[in_channels * 2, out_channels * num_kernels])
with torch.no_grad():
new_features, _, _ = paconv((features, points_xyz, points_idx))
......
import numpy as np
import pytest
import torch
from mmcv.cnn.bricks import ConvModule
from mmdet3d.models.builder import build_head
def test_paconv_decode_head_loss():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
paconv_decode_head_cfg = dict(
type='PAConvHead',
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
(128 + 6, 128, 128, 128)),
channels=128,
num_classes=20,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
ignore_index=20)
self = build_head(paconv_decode_head_cfg)
self.cuda()
assert isinstance(self.conv_seg, torch.nn.Conv1d)
assert self.conv_seg.in_channels == 128
assert self.conv_seg.out_channels == 20
assert self.conv_seg.kernel_size == (1, )
assert isinstance(self.pre_seg_conv, ConvModule)
assert isinstance(self.pre_seg_conv.conv, torch.nn.Conv1d)
assert self.pre_seg_conv.conv.in_channels == 128
assert self.pre_seg_conv.conv.out_channels == 128
assert self.pre_seg_conv.conv.kernel_size == (1, )
assert isinstance(self.pre_seg_conv.bn, torch.nn.BatchNorm1d)
assert self.pre_seg_conv.bn.num_features == 128
assert isinstance(self.pre_seg_conv.activate, torch.nn.ReLU)
# test forward
sa_xyz = [
torch.rand(2, 4096, 3).float().cuda(),
torch.rand(2, 1024, 3).float().cuda(),
torch.rand(2, 256, 3).float().cuda(),
torch.rand(2, 64, 3).float().cuda(),
torch.rand(2, 16, 3).float().cuda(),
]
sa_features = [
torch.rand(2, 6, 4096).float().cuda(),
torch.rand(2, 64, 1024).float().cuda(),
torch.rand(2, 128, 256).float().cuda(),
torch.rand(2, 256, 64).float().cuda(),
torch.rand(2, 512, 16).float().cuda(),
]
input_dict = dict(sa_xyz=sa_xyz, sa_features=sa_features)
seg_logits = self(input_dict)
assert seg_logits.shape == torch.Size([2, 20, 4096])
# test loss
pts_semantic_mask = torch.randint(0, 20, (2, 4096)).long().cuda()
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
# test loss with ignore_index
ignore_index_mask = torch.ones_like(pts_semantic_mask) * 20
losses = self.losses(seg_logits, ignore_index_mask)
assert losses['loss_sem_seg'].item() == 0
# test loss with class_weight
paconv_decode_head_cfg['loss_decode'] = dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=np.random.rand(20),
loss_weight=1.0)
self = build_head(paconv_decode_head_cfg)
self.cuda()
losses = self.losses(seg_logits, pts_semantic_mask)
assert losses['loss_sem_seg'].item() > 0
......@@ -159,3 +159,147 @@ def test_pointnet2_msg():
results = self.aug_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([500])
assert results[1]['semantic_mask'].shape == torch.Size([200])
def test_paconv_ssg():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
set_random_seed(0, True)
paconv_ssg_cfg = _get_segmentor_cfg(
'paconv/paconv_ssg_8x2_step_100e_s3dis_seg-3d-13class.py')
# for GPU memory consideration
paconv_ssg_cfg.backbone.num_points = (256, 64, 16, 4)
paconv_ssg_cfg.test_cfg.num_points = 32
self = build_segmentor(paconv_ssg_cfg).cuda()
points = [torch.rand(1024, 9).float().cuda() for _ in range(2)]
img_metas = [dict(), dict()]
gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)]
# test forward_train
losses = self.forward_train(points, img_metas, gt_masks)
assert losses['decode.loss_sem_seg'].item() >= 0
assert losses['regularize.loss_regularize'].item() >= 0
# test forward function
set_random_seed(0, True)
data_dict = dict(
points=points, img_metas=img_metas, pts_semantic_mask=gt_masks)
forward_losses = self.forward(return_loss=True, **data_dict)
assert np.allclose(losses['decode.loss_sem_seg'].item(),
forward_losses['decode.loss_sem_seg'].item())
assert np.allclose(losses['regularize.loss_regularize'].item(),
forward_losses['regularize.loss_regularize'].item())
# test loss with ignore_index
ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)]
losses = self.forward_train(points, img_metas, ignore_masks)
assert losses['decode.loss_sem_seg'].item() == 0
# test simple_test
self.eval()
with torch.no_grad():
scene_points = [
torch.randn(200, 6).float().cuda() * 3.0,
torch.randn(100, 6).float().cuda() * 2.5
]
results = self.simple_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test forward function calling simple_test
with torch.no_grad():
data_dict = dict(points=[scene_points], img_metas=[img_metas])
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test aug_test
with torch.no_grad():
scene_points = [
torch.randn(2, 200, 6).float().cuda() * 3.0,
torch.randn(2, 100, 6).float().cuda() * 2.5
]
img_metas = [[dict(), dict()], [dict(), dict()]]
results = self.aug_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test forward function calling aug_test
with torch.no_grad():
data_dict = dict(points=scene_points, img_metas=img_metas)
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
def test_paconv_cuda_ssg():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
set_random_seed(0, True)
paconv_cuda_ssg_cfg = _get_segmentor_cfg(
'paconv/paconv_ssg_8x2_step_100e_s3dis_seg-3d-13class.py')
# for GPU memory consideration
paconv_cuda_ssg_cfg.backbone.num_points = (256, 64, 16, 4)
paconv_cuda_ssg_cfg.test_cfg.num_points = 32
self = build_segmentor(paconv_cuda_ssg_cfg).cuda()
points = [torch.rand(1024, 9).float().cuda() for _ in range(2)]
img_metas = [dict(), dict()]
gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)]
# test forward_train
losses = self.forward_train(points, img_metas, gt_masks)
assert losses['decode.loss_sem_seg'].item() >= 0
assert losses['regularize.loss_regularize'].item() >= 0
# test forward function
set_random_seed(0, True)
data_dict = dict(
points=points, img_metas=img_metas, pts_semantic_mask=gt_masks)
forward_losses = self.forward(return_loss=True, **data_dict)
assert np.allclose(losses['decode.loss_sem_seg'].item(),
forward_losses['decode.loss_sem_seg'].item())
assert np.allclose(losses['regularize.loss_regularize'].item(),
forward_losses['regularize.loss_regularize'].item())
# test loss with ignore_index
ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)]
losses = self.forward_train(points, img_metas, ignore_masks)
assert losses['decode.loss_sem_seg'].item() == 0
# test simple_test
self.eval()
with torch.no_grad():
scene_points = [
torch.randn(200, 6).float().cuda() * 3.0,
torch.randn(100, 6).float().cuda() * 2.5
]
results = self.simple_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test forward function calling simple_test
with torch.no_grad():
data_dict = dict(points=[scene_points], img_metas=[img_metas])
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test aug_test
with torch.no_grad():
scene_points = [
torch.randn(2, 200, 6).float().cuda() * 3.0,
torch.randn(2, 100, 6).float().cuda() * 2.5
]
img_metas = [[dict(), dict()], [dict(), dict()]]
results = self.aug_test(scene_points, img_metas)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])
# test forward function calling aug_test
with torch.no_grad():
data_dict = dict(points=scene_points, img_metas=img_metas)
results = self.forward(return_loss=False, **data_dict)
assert results[0]['semantic_mask'].shape == torch.Size([200])
assert results[1]['semantic_mask'].shape == torch.Size([100])