Skip to content
Snippets Groups Projects
Commit 6fc957df authored by lenhy's avatar lenhy
Browse files

Basic version, loads data, trains and evaluates

parent 22d69c3a
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
import click
import logging
from pathlib import Path
from dotenv import find_dotenv, load_dotenv
import pykeen.datasets
@click.command()
@click.argument('input_filepath', type=click.Path(exists=True))
@click.argument('output_filepath', type=click.Path())
def main(input_filepath, output_filepath):
""" Runs data processing scripts to turn raw data from (../raw) into
cleaned data ready to be analyzed (saved in ../processed).
"""
logger = logging.getLogger(__name__)
logger.info('making final data set from raw data')
def load_data():
data = pykeen.datasets.WN18RR()
return data
if __name__ == '__main__':
log_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=log_fmt)
# not used in this stub but often useful for finding various files
project_dir = Path(__file__).resolve().parents[2]
# find .env automagically by walking up directories until it's found, then
# load up the .env entries as environment variables
load_dotenv(find_dotenv())
main()
dataset = load_data()
print(len(dataset.training.mapped_triples))
\ No newline at end of file
def evaluate(evaluator, model, test_triples, additional_triples, batch_size: int = 1024):
results = evaluator.evaluate(
model=model,
mapped_triples=test_triples,
batch_size=batch_size,
additional_filter_triples=additional_triples,
)
return results
\ No newline at end of file
import pykeen
from pykeen.models import TransE
from torch.optim import Adam
from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import RankBasedEvaluator
from src.data.make_dataset import load_data
from src.models.predict_model import evaluate
def train(model, train_data, optimizer, n_epochs: int = 5, batch_size: int = 256):
training_loop = SLCWATrainingLoop(
model=model,
triples_factory=train_data,
optimizer=optimizer,
)
_ = training_loop.train(
triples_factory=train_data,
num_epochs=n_epochs,
batch_size=batch_size,
)
return model
if __name__ == '__main__':
dataset = load_data()
training_triples_factory = dataset.training
model = TransE(triples_factory=training_triples_factory)
optimizer = Adam(params=model.get_grad_params())
train(model, training_triples_factory, optimizer)
evaluator = RankBasedEvaluator()
test_triples = dataset.testing.mapped_triples[:500]
additional_triples = [
dataset.training.mapped_triples,
dataset.validation.mapped_triples,
]
results = evaluate(evaluator, model, test_triples, additional_triples)
print(f"Hits@1: {results.data[('hits_at_1', 'both', 'realistic')]}")
print(f"Hits@3: {results.data[('hits_at_3', 'both', 'realistic')]}")
print(f"Hits@10: {results.data[('hits_at_10', 'both', 'realistic')]}")
print(f"Arithmetic mean rank (MR): {results.data[('arithmetic_mean_rank', 'both', 'realistic')]}")
print(f"Inverse harmonic mean rank (MR): {results.data[('inverse_harmonic_mean_rank', 'both', 'realistic')]}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment