Skip to content
Snippets Groups Projects
Commit 5ee75ffa authored by lenhy's avatar lenhy
Browse files

More models

parent 6fc957df
No related branches found
No related tags found
No related merge requests found
import pykeen import pykeen
from pykeen.models import TransE from pykeen.models import TransE, TransR, ConvE, DistMult
from torch.optim import Adam from torch.optim import Adam
from pykeen.training import SLCWATrainingLoop from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import RankBasedEvaluator from pykeen.evaluation import RankBasedEvaluator
...@@ -8,6 +8,16 @@ from src.data.make_dataset import load_data ...@@ -8,6 +8,16 @@ from src.data.make_dataset import load_data
from src.models.predict_model import evaluate from src.models.predict_model import evaluate
def get_model(name: str, triples, inverse_triples):
model_dict = {
"TransE": TransE(triples_factory=triples),
"TransR": TransR(triples_factory=triples),
"ConvE": ConvE(triples_factory=inverse_triples),
"DistMult": DistMult(triples_factory=triples),
}
return model_dict[name]
def train(model, train_data, optimizer, n_epochs: int = 5, batch_size: int = 256): def train(model, train_data, optimizer, n_epochs: int = 5, batch_size: int = 256):
training_loop = SLCWATrainingLoop( training_loop = SLCWATrainingLoop(
model=model, model=model,
...@@ -27,8 +37,13 @@ def train(model, train_data, optimizer, n_epochs: int = 5, batch_size: int = 256 ...@@ -27,8 +37,13 @@ def train(model, train_data, optimizer, n_epochs: int = 5, batch_size: int = 256
if __name__ == '__main__': if __name__ == '__main__':
dataset = load_data() dataset = load_data()
training_triples_factory = dataset.training training_triples_factory = dataset.training
training_inverse = training_triples_factory.clone_and_exchange_triples(training_triples_factory.mapped_triples, create_inverse_triples=True)
model = TransE(triples_factory=training_triples_factory) model_names = ["TransE", "TransR", "ConvE", "DistMult"]
model_results = {}
for name in model_names:
model_results[name] = {}
print(name)
model = get_model(name, training_triples_factory, training_inverse)
optimizer = Adam(params=model.get_grad_params()) optimizer = Adam(params=model.get_grad_params())
...@@ -36,15 +51,23 @@ if __name__ == '__main__': ...@@ -36,15 +51,23 @@ if __name__ == '__main__':
evaluator = RankBasedEvaluator() evaluator = RankBasedEvaluator()
test_triples = dataset.testing.mapped_triples[:500] test_triples = dataset.testing.mapped_triples
additional_triples = [ additional_triples = [
dataset.training.mapped_triples, dataset.training.mapped_triples,
dataset.validation.mapped_triples, dataset.validation.mapped_triples,
] ]
results = evaluate(evaluator, model, test_triples, additional_triples) results = evaluate(evaluator, model, test_triples, additional_triples)
print(f"Hits@1: {results.data[('hits_at_1', 'both', 'realistic')]}") model_results[name]["Hits@1"] = results.data[('hits_at_1', 'both', 'realistic')]
print(f"Hits@3: {results.data[('hits_at_3', 'both', 'realistic')]}") model_results[name]["Hits@3"] = results.data[('hits_at_3', 'both', 'realistic')]
print(f"Hits@10: {results.data[('hits_at_10', 'both', 'realistic')]}") model_results[name]["Hits@10"] = results.data[('hits_at_10', 'both', 'realistic')]
print(f"Arithmetic mean rank (MR): {results.data[('arithmetic_mean_rank', 'both', 'realistic')]}") model_results[name]["MR"] = results.data[('arithmetic_mean_rank', 'both', 'realistic')]
print(f"Inverse harmonic mean rank (MR): {results.data[('inverse_harmonic_mean_rank', 'both', 'realistic')]}") model_results[name]["MMR"] = results.data[('inverse_harmonic_mean_rank', 'both', 'realistic')]
\ No newline at end of file
for name in model_names:
print(name)
print(f"Hits@1: {model_results[name]['Hits@1']}")
print(f"Hits@3: {model_results[name]['Hits@3']}")
print(f"Hits@10: {model_results[name]['Hits@5']}")
print(f"Arithmetic mean rank (MR): {model_results[name]['MR']}")
print(f"Inverse harmonic mean rank (MR): {model_results[name]['MRR']}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment