diff --git a/src/models/train_model.py b/src/models/train_model.py
index e42c5c5a917c514e56278bfff1668baa2b5d9352..9844dff051ecf05495e021518af8c1270885ce9f 100644
--- a/src/models/train_model.py
+++ b/src/models/train_model.py
@@ -45,8 +45,8 @@ def get_model(name: str, triples, inverse_triples, regularizer):
         "ConvE": 0.1,
         "DistMult": 0.0,
     }
+    model = model_dict[name]
     if name == "TransE":
-        model = model_dict[name]
         model.loss.margin = 2.0
     return model, lr_dict[name], epochs_dict[name], bs_dict[name], ls_dict[name]