diff --git a/src/models/train_model.py b/src/models/train_model.py index e42c5c5a917c514e56278bfff1668baa2b5d9352..9844dff051ecf05495e021518af8c1270885ce9f 100644 --- a/src/models/train_model.py +++ b/src/models/train_model.py @@ -45,8 +45,8 @@ def get_model(name: str, triples, inverse_triples, regularizer): "ConvE": 0.1, "DistMult": 0.0, } + model = model_dict[name] if name == "TransE": - model = model_dict[name] model.loss.margin = 2.0 return model, lr_dict[name], epochs_dict[name], bs_dict[name], ls_dict[name]