import numbers
from operator import itemgetter

import matplotlib.pyplot as plt
from sklearn.utils import validation
def plot2d(model, x, y, figure, title=""):
    """
    Plot the data in input space and projected to two dimensions.

    The projection is done using the relevances of the given glvq model.
    The first panel always shows the first two input features of ``x``
    together with the predictions of ``model`` and its prototypes.

    * If the model has an ``omegas_`` attribute, one additional panel is drawn
      per prototype, using ``model.project(data, prototype_index, 2)``. The
      panels are ordered by ascending summed distance
      (``model._compute_distance``) between each prototype and the samples
      that carry the prototype's class label.
    * Otherwise a single additional panel is drawn, using
      ``model.project(data, 2)`` for both the samples and all prototypes.

    Parameters
    ----------
    model : GlvqModel that has relevances
        (GrlvqModel,GmlvqModel,LgmlvqModel). Must provide ``predict``,
        ``project``, ``w_`` (prototypes) and ``c_w_`` (prototype labels).
    x : array-like, shape = [n_samples, n_features]
        Input data. Columns 0 and 1 are plotted in the first panel, so at
        least two features are required.
    y : array, shape = [n_samples]
        Input data target
    figure : int
        the figure to plot on
    title : str, optional
        the title to use, optional

    Returns
    -------
    None
        The figure is drawn and shown as a side effect.
    """
    x, y = validation.check_X_y(x, y)
    dim = 2
    f = plt.figure(figure)
    f.suptitle(title)
    pred = model.predict(x)

    if hasattr(model, 'omegas_'):
        nb_prototype = model.w_.shape[0]
        _scatter_input_space(f.add_subplot(1, nb_prototype + 1, 1),
                             model, x, y, pred)

        # Rank prototypes by how close they are (in total) to the samples of
        # their own class; the sort is stable, so ties keep prototype order.
        ranked = sorted(
            ((model._compute_distance(x[y == model.c_w_[i]],
                                      model.w_[i]).sum(), i)
             for i in range(nb_prototype)),
            key=itemgetter(0))

        # Panel 1 is the input space, so the ranked panels start at slot 2.
        for slot, (_, i) in enumerate(ranked, start=2):
            x_p = model.project(x, i, dim, print_variance_covered=True)
            w_p = model.project(model.w_[i], i, dim)

            ax = f.add_subplot(1, nb_prototype + 1, slot)
            ax.scatter(x_p[:, 0], x_p[:, 1], c=_to_tango_colors(y, 0),
                       alpha=0.2)
            # A dark diamond with a colored dot on top marks the prototype.
            ax.scatter(w_p[0], w_p[1],
                       c=_tango_color('aluminium', 5), marker='D')
            ax.scatter(w_p[0], w_p[1],
                       c=_tango_color(i, 0), marker='.')
            ax.axis('equal')
    else:
        _scatter_input_space(f.add_subplot(121), model, x, y, pred)

        x_p = model.project(x, dim, print_variance_covered=True)
        w_p = model.project(model.w_, dim)

        ax = f.add_subplot(122)
        ax.scatter(x_p[:, 0], x_p[:, 1], c=_to_tango_colors(y, 0), alpha=0.5)
        ax.scatter(w_p[:, 0], w_p[:, 1],
                   c=_tango_color('aluminium', 5), marker='D')
        ax.scatter(w_p[:, 0], w_p[:, 1], s=60,
                   c=_to_tango_colors(model.c_w_, 0), marker='.')
        ax.axis('equal')
    f.show()


def _scatter_input_space(ax, model, x, y, pred):
    """Draw samples, predictions and prototypes using input features 0 and 1."""
    # NOTE: every _to_tango_colors call derives its colors from the labels it
    # is given, so the colors of `y` and `pred` may differ for the same class
    # when `pred` does not contain the same classes as `y`.
    ax.scatter(x[:, 0], x[:, 1], c=_to_tango_colors(y), alpha=0.5)
    ax.scatter(x[:, 0], x[:, 1], c=_to_tango_colors(pred), marker='.')
    # A dark diamond with a class-colored dot on top marks each prototype.
    ax.scatter(model.w_[:, 0], model.w_[:, 1],
               c=_tango_color('aluminium', 5), marker='D')
    ax.scatter(model.w_[:, 0], model.w_[:, 1],
               c=_to_tango_colors(model.c_w_, 0), marker='.')
    ax.axis('equal')
# Tango color palette used by the plotting helpers. Each family maps to a
# list of hex color strings ordered from the lightest shade (index 0) to the
# darkest; the list index is the "brightness" argument of _tango_color.
colors = {
    "skyblue": ['#729fcf', '#3465a4', '#204a87'],
    "scarletred": ['#ef2929', '#cc0000', '#a40000'],
    "orange": ['#fcaf3e', '#f57900', '#ce5c00'],
    "plum": ['#ad7fa8', '#75507b', '#5c3566'],
    "chameleon": ['#8ae234', '#73d216', '#4e9a06'],
    # The middle shade was missing its leading '#', which made it an invalid
    # hex color string.
    "butter": ['#fce94f', '#edd400', '#c4a000'],
    "chocolate": ['#e9b96e', '#c17d11', '#8f5902'],
    "aluminium": ['#eeeeec', '#d3d7cf', '#babdb6', '#888a85', '#555753',
                  '#2e3436']
}
# Family names in definition order, so a family can be addressed by index.
color_names = list(colors.keys())
def _tango_color(name, brightness=0):
    """
    Return one hex color string from the module palette ``colors``.

    Parameters
    ----------
    name : str or int
        Either the name of a color family (a key of ``colors``) or an
        integer index into ``color_names``. Any integral type is accepted
        (e.g. NumPy integers); indices wrap around the number of families.
    brightness : int, optional
        Index of the shade inside the chosen family (default 0).

    Returns
    -------
    str
        The hex color string stored at ``colors[name][brightness]``.

    Raises
    ------
    ValueError
        If ``name`` is neither an integer nor a key of ``colors``.
    IndexError
        If ``brightness`` is out of range for the chosen family.
    """
    # bool is formally an Integral, but it is not a meaningful palette index;
    # it keeps falling through to the "not a valid color" error below.
    if isinstance(name, numbers.Integral) and not isinstance(name, bool):
        name = color_names[name % len(color_names)]
    if name in colors:
        return colors[name][brightness]
    raise ValueError('{} is not a valid color'.format(name))
def _to_tango_colors(elems, brightness=0):
    """
    Map every element of ``elems`` to a palette color, one color per label.

    Distinct labels are numbered in sorted order, so the same set of labels
    always gets the same colors regardless of the order in which the labels
    occur or of set iteration order. Labels that cannot be ordered against
    each other fall back to set iteration order. The label number is passed
    to ``_tango_color`` as the color index.

    Parameters
    ----------
    elems : iterable of hashable
        Labels to colorize; iterated more than once.
    brightness : int, optional
        Shade index forwarded to ``_tango_color`` (default 0).

    Returns
    -------
    list of str
        One color string per element of ``elems``, in input order.
    """
    distinct = set(elems)
    try:
        ordered = sorted(distinct)
    except TypeError:
        # Mixed or otherwise unorderable labels: keep the set's own order.
        ordered = list(distinct)
    # Resolve each label's color once instead of searching a list per element.
    shade_of = {label: _tango_color(idx, brightness)
                for idx, label in enumerate(ordered)}
    return [shade_of[e] for e in elems]