From 9a32cc2704332f33d95c318bbac16dbb7eac52a3 Mon Sep 17 00:00:00 2001
From: lenhy <lenhy@dtu.dk>
Date: Thu, 22 Sep 2022 14:36:00 +0200
Subject: [PATCH] Bugfix

---
 src/models/train_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models/train_model.py b/src/models/train_model.py
index e42c5c5..9844dff 100644
--- a/src/models/train_model.py
+++ b/src/models/train_model.py
@@ -45,8 +45,8 @@ def get_model(name: str, triples, inverse_triples, regularizer):
         "ConvE": 0.1,
         "DistMult": 0.0,
     }
+    model = model_dict[name]
     if name == "TransE":
-        model = model_dict[name]
         model.loss.margin = 2.0
     return model, lr_dict[name], epochs_dict[name], bs_dict[name], ls_dict[name]
 
-- 
GitLab