Skip to content
Snippets Groups Projects
Commit e72fd713 authored by lenhy's avatar lenhy
Browse files

hyperparameters changed

parent 4221ef9a
Branches
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ def get_model(name: str, triples, inverse_triples, regularizer):
"TransE": 0.01,
"RotatE": 0.01,
"ConvE": 0.001,
"DistMult": 0.1,
"DistMult": 0.01,
}
epochs_dict = {
"TransE": 500,
......@@ -44,7 +44,7 @@ def get_model(name: str, triples, inverse_triples, regularizer):
def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20, batch_size: int = 256):
early_stopper = EarlyStopper(model=model, evaluator= evaluator, training_triples_factory=train_data,
evaluation_triples_factory= eval_data, frequency=5, relative_delta=0.0, metric='hits_at_k', larger_is_better=True)
evaluation_triples_factory= eval_data, frequency=10, relative_delta=0.0, metric='arithmetic_mean_rank', larger_is_better=False)
training_loop = SLCWATrainingLoop(
model=model,
triples_factory=train_data,
......@@ -72,7 +72,7 @@ if __name__ == '__main__':
model_results[name] = {}
print(name)
regularizer = LpRegularizer(weight=0.0001, p=2.0)
regularizer = LpRegularizer(weight=0.001, p=2.0)
model, lr, epochs, batch_size = get_model(name, training_triples_factory,
training_inverse, regularizer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment