Skip to content
Snippets Groups Projects
Commit 775cda92 authored by lenhy's avatar lenhy
Browse files

Early stopping

parent 302dc70a
Branches
No related tags found
No related merge requests found
import pykeen import pykeen
from pykeen.models import TransE, RotatE, ConvE, DistMult from pykeen.models import TransE, RotatE, ConvE, DistMult
from pykeen.stoppers import EarlyStopper
from torch.optim import Adam from torch.optim import Adam
from pykeen.training import SLCWATrainingLoop from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import RankBasedEvaluator from pykeen.evaluation import RankBasedEvaluator
...@@ -39,7 +40,9 @@ def get_model(name: str, triples, inverse_triples): ...@@ -39,7 +40,9 @@ def get_model(name: str, triples, inverse_triples):
return model_dict[name], lr_dict[name], epochs_dict[name], bs_dict[name] return model_dict[name], lr_dict[name], epochs_dict[name], bs_dict[name]
def train(model, train_data, optimizer, n_epochs: int = 20, batch_size: int = 256): def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20, batch_size: int = 256):
early_stopper = EarlyStopper(model=model, evaluator= evaluator, training_triples_factory=train_data,
evaluation_triples_factory= eval_data)
training_loop = SLCWATrainingLoop( training_loop = SLCWATrainingLoop(
model=model, model=model,
triples_factory=train_data, triples_factory=train_data,
...@@ -50,14 +53,16 @@ def train(model, train_data, optimizer, n_epochs: int = 20, batch_size: int = 25 ...@@ -50,14 +53,16 @@ def train(model, train_data, optimizer, n_epochs: int = 20, batch_size: int = 25
triples_factory=train_data, triples_factory=train_data,
num_epochs=n_epochs, num_epochs=n_epochs,
batch_size=batch_size, batch_size=batch_size,
stopper=early_stopper,
) )
print(early_stopper.best_epoch)
return model return model
if __name__ == '__main__': if __name__ == '__main__':
dataset = load_data() dataset = load_data()
training_triples_factory = dataset.training training_triples_factory = dataset.training
eval_triples_factory = dataset.validation
training_inverse = training_triples_factory.clone_and_exchange_triples(training_triples_factory.mapped_triples, create_inverse_triples=True) training_inverse = training_triples_factory.clone_and_exchange_triples(training_triples_factory.mapped_triples, create_inverse_triples=True)
model_names = ["TransE", "RotatE", "DistMult", "ConvE"] model_names = ["TransE", "RotatE", "DistMult", "ConvE"]
model_results = {} model_results = {}
...@@ -68,10 +73,11 @@ if __name__ == '__main__': ...@@ -68,10 +73,11 @@ if __name__ == '__main__':
optimizer = Adam(lr=lr, params=model.get_grad_params()) optimizer = Adam(lr=lr, params=model.get_grad_params())
train(model, training_triples_factory, optimizer, n_epochs=epochs, batch_size=batch_size)
evaluator = RankBasedEvaluator() evaluator = RankBasedEvaluator()
train(model, training_triples_factory, eval_triples_factory, optimizer,
evaluator, n_epochs=epochs, batch_size=batch_size)
test_triples = dataset.testing.mapped_triples test_triples = dataset.testing.mapped_triples
additional_triples = [ additional_triples = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment