Skip to content
Snippets Groups Projects
Commit 91ee6ecc authored by lenhy's avatar lenhy
Browse files

optimize hits@k

parent febe5c66
No related branches found
No related tags found
No related merge requests found
......@@ -44,7 +44,7 @@ def get_model(name: str, triples, inverse_triples, regularizer):
def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20, batch_size: int = 256):
early_stopper = EarlyStopper(model=model, evaluator= evaluator, training_triples_factory=train_data,
evaluation_triples_factory= eval_data, frequency=5, relative_delta=0.0, metric='arithmetic_mean_rank', larger_is_better=False)
evaluation_triples_factory= eval_data, frequency=5, relative_delta=0.0, metric='hits_at_k', larger_is_better=False)
training_loop = SLCWATrainingLoop(
model=model,
triples_factory=train_data,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment