Unverified commit a800db23, authored by WRH, committed by GitHub

[Fix] Fix compiling error on windows (#766)

* use type long long and dynamic memory allocation

* use int64_t instead of long long
parent fc9e0d9d
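Both bullets in the commit message trace back to the same portability gap: MSVC uses the LLP64 data model, where long is only 32 bits, while PyTorch's int64 tensors need a 64-bit element type (which int64_t guarantees on every platform); MSVC also rejects C99-style variable-length arrays, which GCC and Clang accept as an extension. A minimal sketch of both pitfalls, for illustration only (not part of the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // LLP64 (Windows/MSVC): long is 4 bytes; int64_t is 8 everywhere.
      printf("sizeof(long) = %zu, sizeof(int64_t) = %zu\n",
             sizeof(long), sizeof(int64_t));

      int n = 8;                 // runtime value
      // int vla[n];             // VLA: GCC/Clang extension, MSVC error C2131
      int* arr = new int[n]();   // portable replacement, value-initialized to 0
      delete[] arr;
      return 0;
    }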
@@ -12,6 +12,7 @@ All Rights Reserved 2019-2020.
 #include <torch/extension.h>
 #include <torch/serialize/tensor.h>
+#include <cstdint>
 #include <vector>
 #define CHECK_CUDA(x) \
@@ -103,7 +104,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,
   int boxes_num = boxes.size(0);
   const float *boxes_data = boxes.data_ptr<float>();
-  long *keep_data = keep.data_ptr<long>();
+  int64_t *keep_data = keep.data_ptr<int64_t>();
   const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
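DIVUP here is the repository's ceiling-division helper; assuming the usual definition (not shown in this diff), col_blocks is the number of 64-bit mask words needed to cover boxes_num boxes:

    // Assumed definition, as commonly found alongside these kernels:
    #define DIVUP(m, n) (((m) + (n) - 1) / (n))
    // e.g. boxes_num = 130, THREADS_PER_BLOCK_NMS = 64  ->  col_blocks = 3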
@@ -124,8 +125,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,
   cudaFree(mask_data);
-  unsigned long long remv_cpu[col_blocks];
-  memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+  unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
   int num_to_keep = 0;
@@ -141,6 +141,7 @@ int nms_gpu(at::Tensor boxes, at::Tensor keep,
       }
     }
   }
+  delete[] remv_cpu;
   if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
   return num_to_keep;
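The replaced lines and their replacement are equivalent in effect: the old code zeroed a stack VLA with memset, while the trailing () on the new[] expression value-initializes the heap buffer to zero, so no memset is needed. A sketch of the pattern (the surrounding suppression loop is elided):

    // Before (GCC/Clang only): stack VLA + explicit zeroing
    //   unsigned long long remv_cpu[col_blocks];
    //   memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));

    // After (portable C++): heap array, zeroed by the trailing ()
    unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
    // ... bitmask suppression loop reads and writes remv_cpu[...] ...
    delete[] remv_cpu;  // must run on every exit path, or the buffer leaks

One caveat of the raw new/delete[] pair: an early return between the two would leak; a std::vector<unsigned long long> would give the same zeroed storage with automatic cleanup.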
@@ -157,7 +158,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,
   int boxes_num = boxes.size(0);
   const float *boxes_data = boxes.data_ptr<float>();
-  long *keep_data = keep.data_ptr<long>();
+  int64_t *keep_data = keep.data_ptr<int64_t>();
   const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
@@ -178,8 +179,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,
   cudaFree(mask_data);
-  unsigned long long remv_cpu[col_blocks];
-  memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+  unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
   int num_to_keep = 0;
@@ -195,6 +195,7 @@ int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,
       }
     }
   }
+  delete[] remv_cpu;
   if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
   return num_to_keep;
...
@@ -13,7 +13,7 @@ All Rights Reserved 2019-2020.
 //#define DEBUG
 const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
-const float EPS = 1e-8;
+__device__ const float EPS = 1e-8;
 struct Point {
   float x, y;
   __device__ Point() {}
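A namespace-scope const float defined on the host side is only usable inside __global__ and __device__ functions when the compiler can fold it to a compile-time literal; the MSVC host toolchain does not reliably allow this, hence the __device__ qualifier, which gives the constant a proper definition in GPU memory. A minimal sketch, assuming a device helper that reads EPS (the helper itself is hypothetical):

    __device__ const float EPS = 1e-8;

    // Hypothetical device helper: reads the device-scope constant directly.
    __device__ bool nearly_equal(float a, float b) {
      return fabsf(a - b) < EPS;
    }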
...
@@ -5,6 +5,7 @@
 #include <stdlib.h>
 #include <assert.h>
 #include <cmath>
+#include <cstdint>
 #include <vector>
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -49,7 +50,7 @@ __global__ void assign_score_withk_forward_kernel(const int B, const int N0, con
     const float* points,
     const float* centers,
     const float* scores,
-    const long* knn_idx,
+    const int64_t* knn_idx,
     float* output) {
   // ----- parallel loop for B, N1, K and O ---------
@@ -82,7 +83,7 @@ __global__ void assign_score_withk_backward_points_kernel(const int B, const int
     const int K, const int O, const int aggregate,
     const float* grad_out,
     const float* scores,
-    const long* knn_idx,
+    const int64_t* knn_idx,
     float* grad_points,
     float* grad_centers) {
@@ -116,7 +117,7 @@ __global__ void assign_score_withk_backward_scores_kernel(const int B, const int
     const float* grad_out,
     const float* points,
     const float* centers,
-    const long* knn_idx,
+    const int64_t* knn_idx,
     float* grad_scores) {
   // ----- parallel loop for B, N, K, M ---------
@@ -156,7 +157,7 @@ void assign_score_withk_forward_wrapper(int B, int N0, int N1, int M, int K, int
   const float* points_data = points.data_ptr<float>();
   const float* centers_data = centers.data_ptr<float>();
   const float* scores_data = scores.data_ptr<float>();
-  const long* knn_idx_data = knn_idx.data_ptr<long>();
+  const int64_t* knn_idx_data = knn_idx.data_ptr<int64_t>();
   float* output_data = output.data_ptr<float>();
   dim3 blocks(DIVUP(B*O*N1*K, THREADS_PER_BLOCK));
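The wrapper flattens all four loop dimensions into one 1-D grid, so the last block usually has surplus threads; the kernels are then expected to guard with an index check. A hypothetical shape of that guard (not shown in this diff):

    // Inside the kernel: recover the flat index and bounds-check it.
    const int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= (int64_t)B * (int64_t)O * N1 * K) return;  // surplus threads exit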
@@ -191,7 +192,7 @@ void assign_score_withk_backward_wrapper(int B, int N0, int N1, int M, int K, in
   const float* points_data = points.data_ptr<float>();
   const float* centers_data = centers.data_ptr<float>();
   const float* scores_data = scores.data_ptr<float>();
-  const long* knn_idx_data = knn_idx.data_ptr<long>();
+  const int64_t* knn_idx_data = knn_idx.data_ptr<int64_t>();
   float* grad_points_data = grad_points.data_ptr<float>();
   float* grad_centers_data = grad_centers.data_ptr<float>();
   float* grad_scores_data = grad_scores.data_ptr<float>();
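data_ptr<T>() is checked against the tensor's dtype, so these wrappers assume knn_idx was created as an int64 (torch.long) tensor. The reason data_ptr<long>() broke only on Windows: on Linux, long aliases int64_t, so a matching specialization exists; under LLP64, long is a distinct 32-bit type with no specialization, which is the compile error this commit fixes. An illustrative caller (names are assumptions, not from the patch):

    #include <torch/extension.h>
    #include <cstdint>

    void example_caller(int B, int N1, int K) {
      // dtype must be int64 for data_ptr<int64_t>() to succeed.
      at::Tensor knn_idx = at::zeros({B, N1, K}, at::kLong);
      const int64_t* knn_idx_data = knn_idx.data_ptr<int64_t>();
      (void)knn_idx_data;  // pass to the CUDA wrapper as in the diff above
    }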
...
@@ -14,7 +14,8 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
                              const int NDim) {
   const int ndim_minus_1 = NDim - 1;
   bool failed = false;
-  int coor[NDim];
+  // int coor[NDim];
+  int* coor = new int[NDim]();
   int c;
   for (int i = 0; i < num_points; ++i) {
@@ -37,6 +38,7 @@ void dynamic_voxelize_kernel(const torch::TensorAccessor<T, 2> points,
     }
   }
+  delete[] coor;
   return;
 }
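This is the same VLA fix as in nms_gpu; since dynamic_voxelize_kernel is host-side C++ (it operates on a torch::TensorAccessor), a std::vector would achieve the same thing without a manual delete[]. A sketch of that alternative, offered as an assumption about style rather than what the patch does:

    #include <vector>

    void sketch(const int NDim) {
      std::vector<int> coor(NDim, 0);  // zero-initialized, freed automatically
      // use coor.data() wherever a raw int* is required
    }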
...