From 9a32cc2704332f33d95c318bbac16dbb7eac52a3 Mon Sep 17 00:00:00 2001 From: lenhy <lenhy@dtu.dk> Date: Thu, 22 Sep 2022 14:36:00 +0200 Subject: [PATCH] Bugfix --- src/models/train_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/train_model.py b/src/models/train_model.py index e42c5c5..9844dff 100644 --- a/src/models/train_model.py +++ b/src/models/train_model.py @@ -45,8 +45,8 @@ def get_model(name: str, triples, inverse_triples, regularizer): "ConvE": 0.1, "DistMult": 0.0, } + model = model_dict[name] if name == "TransE": - model = model_dict[name] model.loss.margin = 2.0 return model, lr_dict[name], epochs_dict[name], bs_dict[name], ls_dict[name] -- GitLab