Skip to content
Snippets Groups Projects
Commit a02ac232 authored by lenhy
Browse files

Bugfix and more hyperparameters

parent e905f9eb
No related branches found
No related tags found
No related merge requests found
import pykeen import pykeen
from pykeen.models import TransE, TransR, ConvE, DistMult from pykeen.models import TransE, RotatE, ConvE, DistMult
from torch.optim import Adam from torch.optim import Adam
from pykeen.training import SLCWATrainingLoop from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import RankBasedEvaluator from pykeen.evaluation import RankBasedEvaluator
...@@ -12,7 +12,7 @@ def get_model(name: str, triples, inverse_triples): ...@@ -12,7 +12,7 @@ def get_model(name: str, triples, inverse_triples):
model_dict = { model_dict = {
"TransE": TransE(triples_factory=triples, embedding_dim=20, "TransE": TransE(triples_factory=triples, embedding_dim=20,
scoring_fct_norm=1), scoring_fct_norm=1),
"TransR": TransR(triples_factory=triples, embedding_dim=50, "RotatE": RotatE(triples_factory=triples, embedding_dim=125,
relation_dim=50, scoring_fct_norm=1), relation_dim=50, scoring_fct_norm=1),
"ConvE": ConvE(triples_factory=inverse_triples, embedding_dim=200, "ConvE": ConvE(triples_factory=inverse_triples, embedding_dim=200,
feature_map_dropout= 0.2, input_dropout= 0.2, feature_map_dropout= 0.2, input_dropout= 0.2,
...@@ -21,11 +21,23 @@ def get_model(name: str, triples, inverse_triples): ...@@ -21,11 +21,23 @@ def get_model(name: str, triples, inverse_triples):
} }
lr_dict = { lr_dict = {
"TransE": 0.01, "TransE": 0.01,
"TransR": 0.001, "RotatE": 5.0e-6,
"ConvE": 0.001, "ConvE": 0.001,
"DistMult": 0.1, "DistMult": 0.1,
} }
return model_dict[name], lr_dict[name] epochs_dict = {
"TransE": 500,
"RotatE": 6000,
"ConvE": 500,
"DistMult": 300,
}
bs_dict = {
"TransE": 1440,
"RotatE": 100000,
"ConvE": 128,
"DistMult": 4000,
}
return model_dict[name], lr_dict[name], epochs_dict[name], bs_dict[name]
def train(model, train_data, optimizer, n_epochs: int = 20, batch_size: int = 256): def train(model, train_data, optimizer, n_epochs: int = 20, batch_size: int = 256):
...@@ -48,16 +60,16 @@ if __name__ == '__main__': ...@@ -48,16 +60,16 @@ if __name__ == '__main__':
dataset = load_data() dataset = load_data()
training_triples_factory = dataset.training training_triples_factory = dataset.training
training_inverse = training_triples_factory.clone_and_exchange_triples(training_triples_factory.mapped_triples, create_inverse_triples=True) training_inverse = training_triples_factory.clone_and_exchange_triples(training_triples_factory.mapped_triples, create_inverse_triples=True)
model_names = ["TransE", "TransR", "ConvE", "DistMult"] model_names = ["TransE", "RotatE", "DistMult", "ConvE"]
model_results = {} model_results = {}
for name in model_names: for name in model_names:
model_results[name] = {} model_results[name] = {}
print(name) print(name)
model, lr = get_model(name, training_triples_factory, training_inverse) model, lr, epochs, batch_size = get_model(name, training_triples_factory, training_inverse)
optimizer = Adam(lr=lr, params=model.get_grad_params()) optimizer = Adam(lr=lr, params=model.get_grad_params())
train(model, training_triples_factory, optimizer) train(model, training_triples_factory, optimizer, n_epochs=epochs, batch_size=batch_size)
evaluator = RankBasedEvaluator() evaluator = RankBasedEvaluator()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment