# Source code for catalyst.contrib.callbacks.knn_metric
from typing import Dict, List, TYPE_CHECKING
from math import ceil
import numpy as np
from scipy import stats
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
)
from sklearn.neighbors import NearestNeighbors
import torch
from catalyst.core.callback import Callback, CallbackOrder
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
class KNNMetricCallback(Callback):
    """A callback that returns single metric on ``runner.on_loader_end``."""

    def __init__(
        self,
        input_key: str = "logits",
        output_key: str = "targets",
        prefix: str = "knn",
        num_classes: int = 2,
        class_names: dict = None,
        cv_loader_names: Dict[str, List[str]] = None,
        metric_fn: str = "f1-score",
        knn_metric: str = "euclidean",
        num_neighbors: int = 5,
    ):
        """Returns metric value calculated using kNN algorithm.

        Args:
            input_key: input key to get features.
            output_key: output key to get targets.
            prefix: key to store in logs.
            num_classes: number of classes; must be > 1.
            class_names: mapping of class indexes to class names.
            cv_loader_names: dict with keys and values of loader_names
                for which cross validation should be calculated.
                For example {"train" : ["valid", "test"]}.
            metric_fn: one of `accuracy`, `precision`, `recall`, `f1-score`.
                default is `f1-score`.
            knn_metric: look sklearn.neighbors.NearestNeighbors parameter.
            num_neighbors: number of neighbors, default is 5.
        """
        super().__init__(CallbackOrder.metric)

        assert num_classes > 1, "`num_classes` should be more than 1"

        metric_fns = {
            "accuracy": accuracy_score,
            "recall": recall_score,
            "precision": precision_score,
            "f1-score": f1_score,
        }
        assert (
            metric_fn in metric_fns
        ), f"Metric function with value `{metric_fn}` not implemented"

        self.prefix = prefix
        self.features_key = input_key
        self.targets_key = output_key

        self.num_classes = num_classes
        self.class_names = (
            class_names
            if class_names is not None
            else [str(i) for i in range(num_classes)]
        )
        self.cv_loader_names = cv_loader_names

        self.metric_fn = metric_fns[metric_fn]
        self.knn_metric = knn_metric
        self.num_neighbors = num_neighbors

        # number of chunks the test data is split into during prediction;
        # doubled every time `_knn` hits a MemoryError.
        self.num_folds = 1

        self._reset_cache()
        self._reset_sets()

    def _reset_cache(self):
        """Function to reset cache for features and labels."""
        self.features = []
        self.targets = []

    def _reset_sets(self):
        """Function to reset cache for all sets."""
        self.sets = {}

    def _knn(self, train_set, test_set=None):
        """Returns accuracy calculated using kNN algorithm.

        Args:
            train_set: dict of feature "values" and "labels" for training set.
            test_set: dict of feature "values" and "labels" for test set.

        Returns:
            cm: tuple of lists of true & predicted classes.
        """
        # if the test_set is None, we will test train_set on itself,
        # in that case we need to delete the closest neighbor
        # (the query point itself, found at distance 0)
        leave_one_out = test_set is None

        if leave_one_out:
            test_set = train_set

        x_train, y_train = train_set["values"], train_set["labels"]
        x_test, y_test = test_set["values"], test_set["labels"]

        # FIX: iterate over the *test* set size; with cross-validation the
        # train and test sets differ in length, and using len(y_train) here
        # truncated (or over-ran) the predictions.
        size = len(y_test)

        result = None
        while result is None:
            try:
                y_pred = []

                # fit nearest neighbors class on our train data
                classifier = NearestNeighbors(
                    # sklearn's keyword is `n_neighbors` (not `num_neighbors`);
                    # +1 neighbor in leave-one-out mode to compensate for
                    # dropping the query point itself below
                    n_neighbors=self.num_neighbors + int(leave_one_out),
                    metric=self.knn_metric,
                    algorithm="brute",
                )
                classifier.fit(x_train, y_train)

                # data could be evaluated in num_folds in order to avoid OOM
                end_idx, batch_size = 0, ceil(size / self.num_folds)

                for start_idx in range(0, size, batch_size):
                    end_idx = min(start_idx + batch_size, size)

                    x = x_test[start_idx:end_idx]

                    knn_ids = classifier.kneighbors(x, return_distance=False)

                    # if we predict train set on itself we have to delete 0th
                    # neighbor for all of the distances
                    if leave_one_out:
                        knn_ids = knn_ids[:, 1:]

                    # calculate the most frequent class across k neighbors
                    knn_classes = y_train[knn_ids]
                    knn_classes, _ = stats.mode(knn_classes, axis=1)

                    # `reshape(-1)` keeps this working whether scipy returns
                    # the modes 2-D (keepdims=True, scipy < 1.11 default) or
                    # 1-D (keepdims=False, scipy >= 1.11 default)
                    y_pred.extend(
                        np.asarray(knn_classes).reshape(-1).tolist()
                    )

                y_pred = np.asarray(y_pred)
                result = (y_test, y_pred)

            # this try catch block made because sometimes sets are quite big
            # and it is not possible to put everything in memory, so we split
            except MemoryError:
                print(
                    f"Memory error with {self.num_folds} folds, trying more."
                )
                self.num_folds *= 2
                result = None

        return result

    def on_batch_end(self, runner: "IRunner") -> None:
        """Batch end hook.

        Accumulates the batch features and targets for the current loader.

        Args:
            runner: current runner
        """
        features: torch.Tensor = runner.output[
            self.features_key
        ].cpu().detach().numpy()
        targets: torch.Tensor = runner.input[
            self.targets_key
        ].cpu().detach().numpy()

        self.features.extend(features)
        self.targets.extend(targets)

    def on_loader_end(self, runner: "IRunner") -> None:
        """Loader end hook.

        Runs leave-one-out kNN on the accumulated features and logs the
        resulting metric(s) into ``runner.loader_metrics``.

        Args:
            runner: current runner

        Raises:
            Warning: if targets has more classes than num classes
        """
        self.features = np.stack(self.features)
        self.targets = np.stack(self.targets)

        if len(np.unique(self.targets)) > self.num_classes:
            raise Warning("Targets has more classes than num_classes")

        s = {
            "values": self.features,
            "labels": self.targets,
        }

        self.sets[runner.loader_key] = s

        y_true, y_pred = self._knn(s)

        loader_values = runner.loader_metrics
        if self.num_classes == 2:
            loader_values[self.prefix] = self.metric_fn(
                y_true, y_pred, average="binary"
            )
        else:
            values = self.metric_fn(y_true, y_pred, average=None)

            loader_values[f"{self.prefix}"] = np.mean(values)
            for i, value in enumerate(values):
                loader_values[f"{self.prefix}/{self.class_names[i]}"] = value

        self._reset_cache()

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Epoch end hook.

        Cross-validates the stored loader sets against each other as
        configured by ``cv_loader_names`` and logs into
        ``runner.epoch_metrics``.

        Args:
            runner: current runner
        """
        if self.cv_loader_names is not None:
            for k, vs in self.cv_loader_names.items():
                # checking for presence of subset
                if k not in self.sets:
                    print(
                        f"Set `{k}` not found in the sets. "
                        f"Please change `cv_loader_names` parameter."
                    )
                    continue

                for v in vs:
                    # checking for presence of subset
                    if v not in self.sets:
                        print(
                            f"Set `{v}` not found in the sets. "
                            f"Please change `cv_loader_names` parameter."
                        )
                        continue

                    y_true, y_pred = self._knn(self.sets[k], self.sets[v])

                    loader_values = runner.epoch_metrics[f"{k}_{v}_cv"]

                    if self.num_classes == 2:
                        loader_values[f"{self.prefix}"] = self.metric_fn(
                            y_true, y_pred, average="binary"
                        )
                    else:
                        values = self.metric_fn(y_true, y_pred, average=None)

                        loader_values[f"{self.prefix}"] = np.mean(values)
                        for i, value in enumerate(values):
                            prefix = f"{self.prefix}/{self.class_names[i]}"
                            loader_values[prefix] = value

        self._reset_cache()
        self._reset_sets()
# Public API of this module.
__all__ = ["KNNMetricCallback"]