Skip to content
Snippets Groups Projects
Commit 4d37a7fc authored by lenhy's avatar lenhy
Browse files

Label smoothing

parent 43d2ed5b
Branches
No related tags found
No related merge requests found
......@@ -39,10 +39,16 @@ def get_model(name: str, triples, inverse_triples, regularizer):
"ConvE": 128,
"DistMult": 4000,
}
return model_dict[name], lr_dict[name], epochs_dict[name], bs_dict[name]
ls_dict = {
"TransE": 0.0,
"RotatE": 0.0,
"ConvE": 0.1,
"DistMult": 0.0,
}
return model_dict[name], lr_dict[name], epochs_dict[name], bs_dict[name], ls_dict[name]
def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20, batch_size: int = 256):
def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20, batch_size: int = 256, label_smoothing: float = 0.0):
early_stopper = EarlyStopper(model=model, evaluator= evaluator, training_triples_factory=train_data,
evaluation_triples_factory= eval_data, frequency=10, relative_delta=0.0, metric='arithmetic_mean_rank', larger_is_better=False)
training_loop = SLCWATrainingLoop(
......@@ -56,6 +62,7 @@ def train(model, train_data, eval_data, optimizer, evaluator, n_epochs: int = 20
num_epochs=n_epochs,
batch_size=batch_size,
stopper=early_stopper,
label_smoothing=label_smoothing,
)
print(early_stopper.best_epoch)
return model
......@@ -73,7 +80,7 @@ if __name__ == '__main__':
print(name)
regularizer = LpRegularizer(weight=0.0001, p=2.0)
model, lr, epochs, batch_size = get_model(name, training_triples_factory,
model, lr, epochs, batch_size, label_smoothing = get_model(name, training_triples_factory,
training_inverse, regularizer)
optimizer = Adam(lr=lr, params=model.get_grad_params())
......@@ -81,7 +88,7 @@ if __name__ == '__main__':
evaluator = RankBasedEvaluator()
train(model, training_triples_factory, eval_triples_factory, optimizer,
evaluator, n_epochs=epochs, batch_size=batch_size)
evaluator, n_epochs=epochs, batch_size=batch_size, label_smoothing=label_smoothing)
test_triples = dataset.testing.mapped_triples
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment