Skip to content
Snippets Groups Projects
Commit 3c4803a6 authored by Brad Nelson's avatar Brad Nelson
Browse files

sparse F2 vector class implemeneted

parent f7dc82ac
No related branches found
No related tags found
No related merge requests found
#include "cocycle.h" #include "cocycle.h"
#include "sparsevec.h"
#include <vector> #include <vector>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -6,82 +7,27 @@ namespace py = pybind11; ...@@ -6,82 +7,27 @@ namespace py = pybind11;
void Cocycle::insert(int x){ void Cocycle::insert(int x){
cochain.push_back(x); cochain.insert(x);
} }
// add x to cocycle // add x to cocycle
// IMPORTANT: this function assumes that cocycles are sorted! // IMPORTANT: this function assumes that cocycles are sorted!
void Cocycle::add(const Cocycle &x){ void Cocycle::add(const Cocycle &x){
// quick check to see if there is anything to do cochain.add(x.cochain);
if (x.cochain.size() == 0) {return;}
if (cochain.size() == 0) {cochain = x.cochain; return;}
// now we know there is something non-trivial to do
std::vector<int> tmp;
size_t i1 = 0;
size_t i2 = 0;
do {
size_t v1 = cochain[i1];
size_t v2 = x.cochain[i2];
if (v1 == v2) {
// F2 means sum is 0
i1++;
i2++;
} else if (v1 < v2) {
tmp.push_back(v1);
i1++;
} else { // v2 < v1
tmp.push_back(v2);
i2++;
}
} while (i1 < cochain.size() && i2 < x.cochain.size());
// run through rest of entries and dump in
// only one of the loops will actually do anything
while (i1 < cochain.size()) {
tmp.push_back(cochain[i1]);
i1++;
}
while (i2 < x.cochain.size()) {
tmp.push_back(x.cochain[i2]);
i2++;
}
cochain = tmp;
return; return;
} }
// take dot product with cocycle // take dot product with cocycle
// IMPORTANT: this function assumes that cocycles are sorted! // IMPORTANT: this function assumes that cocycles are sorted!
int Cocycle::dot(const Cocycle &x) const{ int Cocycle::dot(const Cocycle &x) const{
// inner product return cochain.dot(x.cochain);
// quick check to see if anything to be done
if (cochain.size() == 0 || x.cochain.size() == 0) return 0;
// loop over indices to compute size of intersection
size_t i1 = 0;
size_t i2 = 0;
size_t intersection = 0;
do {
auto v1 = cochain[i1];
auto v2 = x.cochain[i2];
if (v1 == v2) {
i1++;
i2++;
intersection++;
} else if (v1 < v2) {
i1++;
} else { // v2 < v1
i2++;
}
} while (i1 < cochain.size() && i2 < x.cochain.size());
// std::set<int> tmp;
// std::set_intersection(x.cochain.begin(), x.cochain.end(), cochain.begin(),cochain.end(),std::inserter(tmp,tmp.begin()));
// return tmp.size()%2;
return intersection % 2;
} }
int Cocycle::dim() const{ int Cocycle::dim() const{
return (cochain.size()==0) ? 0 : cochain.size()-1; return (cochain.nzinds.size()==0) ? 0 : cochain.nzinds.size()-1;
} }
void Cocycle::print(){ void Cocycle::print(){
py::print(index, " : ", cochain); py::print(index, " : ");
cochain.print();
} }
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <vector> #include <vector>
#include <cstddef> #include <cstddef>
#include "sparsevec.h"
class Cocycle{ class Cocycle{
...@@ -12,13 +13,13 @@ class Cocycle{ ...@@ -12,13 +13,13 @@ class Cocycle{
// non-zero entries // non-zero entries
// IMPORTANT: this is assumed to always be sorted! // IMPORTANT: this is assumed to always be sorted!
std::vector<int> cochain; SparseF2Vec<int> cochain;
// we should never have this // we should never have this
Cocycle() : index(-1){} Cocycle() : index(-1){}
// initializations // initializations
Cocycle(int x) : index(x) {cochain.push_back(x); } Cocycle(int x) : index(x) , cochain(x) {}
Cocycle(int x, std::vector<int> y) : index(x) , cochain(y) {} Cocycle(int x, std::vector<int> y) : index(x) , cochain(y) {}
// for debug purposes // for debug purposes
......
#ifndef _SPARSEVEC_H
#define _SPARSEVEC_H
#include <vector>
#include <cstddef>
#include <torch/extension.h>
namespace py = pybind11;
/*
Sparse vector definition
header-only file.
*/
template <typename T>
class SparseF2Vec{
public:
// non-zero entries
// IMPORTANT: this is assumed to always be sorted!
std::vector<T> nzinds;
// initialize with empty nzinds
SparseF2Vec() {}
// non-trivial initialization
SparseF2Vec(std::vector<T> x) : nzinds(x) {}
SparseF2Vec(T x) {nzinds.push_back(x);}
// insert index to nzinds
void insert(T x) {
nzinds.push_back(x);
// TODO: sort indices?
}
// add two vectors over F2
void add(const SparseF2Vec<T> &x) {
// quick check to see if there is anything to do
if (x.nzinds.size() == 0) {return;}
if (nzinds.size() == 0) {nzinds = x.nzinds; return;}
// now we know there is something non-trivial to do
std::vector<T> tmp;
size_t i1 = 0;
size_t i2 = 0;
do {
size_t v1 = nzinds[i1];
size_t v2 = x.nzinds[i2];
if (v1 == v2) {
// F2 means sum is 0
i1++;
i2++;
} else if (v1 < v2) {
tmp.push_back(v1);
i1++;
} else { // v2 < v1
tmp.push_back(v2);
i2++;
}
} while (i1 < nzinds.size() && i2 < x.nzinds.size());
// run through rest of entries and dump in
// only one of the loops will actually do anything
while (i1 < nzinds.size()) {
tmp.push_back(nzinds[i1]);
i1++;
}
while (i2 < x.nzinds.size()) {
tmp.push_back(x.nzinds[i2]);
i2++;
}
nzinds = tmp;
return;
}
// dot product of two vectors
// IMPORTANT: this function assumes that nzinds are sorted!
int dot(const SparseF2Vec<T> &x) const {
// inner product
// quick check to see if anything to be done
if (nzinds.size() == 0 || x.nzinds.size() == 0) return 0;
// loop over indices to compute size of intersection
size_t i1 = 0;
size_t i2 = 0;
size_t intersection = 0;
do {
auto v1 = nzinds[i1];
auto v2 = x.nzinds[i2];
if (v1 == v2) {
i1++;
i2++;
intersection++;
} else if (v1 < v2) {
i1++;
} else { // v2 < v1
i2++;
}
} while (i1 < nzinds.size() && i2 < x.nzinds.size());
// mod-2
return intersection % 2;
}
// debug function
void print() {
py::print(nzinds);
}
};
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment