import numbers
from operator import itemgetter

import matplotlib.pyplot as plt
from sklearn.utils import validation
def _scatter_input_space(ax, model, x, y, pred):
    """Draw samples and prototypes of ``model`` on the first two input features."""
    # Translucent dots are colored by the true label; the small '.' markers
    # drawn on top are colored by the model's prediction.
    ax.scatter(x[:, 0], x[:, 1], c=_to_tango_colors(y), alpha=0.5)
    ax.scatter(x[:, 0], x[:, 1], c=_to_tango_colors(pred), marker='.')
    # Prototypes: an aluminium diamond with a dot colored by prototype label.
    ax.scatter(model.w_[:, 0], model.w_[:, 1],
               c=_tango_color('aluminium', 5), marker='D')
    ax.scatter(model.w_[:, 0], model.w_[:, 1],
               c=_to_tango_colors(model.c_w_, 0), marker='.')
    ax.axis('equal')


def plot2d(model, x, y, figure, title=""):
    """
    Projects the input data to two dimensions and plots it. The projection is
    done using the relevances of the given glvq model.

    The first panel always shows the data and prototypes on the first two
    input features. Models exposing ``omegas_`` get one extra panel per
    prototype (data projected with that prototype's projection); all other
    models get a single extra panel with the global projection.

    Parameters
    ----------
    model : GlvqModel that has relevances
        (GrlvqModel,GmlvqModel,LgmlvqModel)
    x : array-like, shape = [n_samples, n_features]
        Input data
    y : array, shape = [n_samples]
        Input data target
    figure : int
        the figure to plot on
    title : str, optional
        the title to use, optional
    """
    x, y = validation.check_X_y(x, y)
    dim = 2
    f = plt.figure(figure)
    f.suptitle(title)
    pred = model.predict(x)
    if hasattr(model, 'omegas_'):
        # Per-prototype projections (presumably LgmlvqModel, which keeps one
        # relevance matrix per prototype -- confirm against the model class).
        nb_prototype = model.w_.shape[0]
        ax = f.add_subplot(1, nb_prototype + 1, 1)
        _scatter_input_space(ax, model, x, y, pred)
        # Lay out prototype panels by ascending summed distance between each
        # prototype and the samples of its own class. sorted() is stable, so
        # ties keep prototype-index order.
        own_class_dist = [
            model._compute_distance(x[y == model.c_w_[i]], model.w_[i]).sum()
            for i in range(nb_prototype)
        ]
        order = sorted(range(nb_prototype), key=own_class_dist.__getitem__)
        for panel, i in enumerate(order, start=2):
            x_p = model.project(x, i, dim, print_variance_covered=True)
            w_p = model.project(model.w_[i], i, dim)
            ax = f.add_subplot(1, nb_prototype + 1, panel)
            ax.scatter(x_p[:, 0], x_p[:, 1], c=_to_tango_colors(y, 0),
                       alpha=0.2)
            ax.scatter(w_p[0], w_p[1],
                       c=_tango_color('aluminium', 5), marker='D')
            ax.scatter(w_p[0], w_p[1],
                       c=_tango_color(i, 0), marker='.')
            ax.axis('equal')
    else:
        ax = f.add_subplot(121)
        _scatter_input_space(ax, model, x, y, pred)
        # Single global projection applied to both data and prototypes.
        x_p = model.project(x, dim, print_variance_covered=True)
        w_p = model.project(model.w_, dim)
        ax = f.add_subplot(122)
        ax.scatter(x_p[:, 0], x_p[:, 1], c=_to_tango_colors(y, 0), alpha=0.5)
        ax.scatter(w_p[:, 0], w_p[:, 1],
                   c=_tango_color('aluminium', 5), marker='D')
        ax.scatter(w_p[:, 0], w_p[:, 1], s=60,
                   c=_to_tango_colors(model.c_w_, 0), marker='.')
        ax.axis('equal')
    f.show()
# Tango palette: each entry lists hex shades ordered from light to dark.
# "aluminium" has six shades; every other hue has three.
colors = {
    "skyblue": ['#729fcf', '#3465a4', '#204a87'],
    "scarletred": ['#ef2929', '#cc0000', '#a40000'],
    "orange": ['#fcaf3e', '#f57900', '#ce5c00'],
    "plum": ['#ad7fa8', '#75507b', '#5c3566'],
    "chameleon": ['#8ae234', '#73d216', '#4e9a06'],
    "butter": ['#fce94f', '#edd400', '#c4a000'],
    "chocolate": ['#e9b96e', '#c17d11', '#8f5902'],
    "aluminium": ['#eeeeec', '#d3d7cf', '#babdb6', '#888a85', '#555753',
                  '#2e3436']
}
# Palette names in definition order; integer color indices are resolved
# against this list.
color_names = list(colors.keys())
def _tango_color(name, brightness=0):
    """
    Look up a hex color string from the Tango palette.

    Parameters
    ----------
    name : str or int
        A palette name (a key of ``colors``) or an integer index into
        ``color_names``. Any integral value, including numpy integer scalars,
        is accepted and wraps around modulo the number of palette entries,
        so every class/prototype index maps to some color.
    brightness : int, optional
        Index into the shade list of the selected palette entry.

    Returns
    -------
    str
        The hex color string.

    Raises
    ------
    ValueError
        If ``name`` is neither an integer nor a known palette name.
    IndexError
        If ``brightness`` is out of range for the selected palette entry.
    """
    # bool is an Integral subclass but not a meaningful palette index, so it
    # is left to fail the name lookup below.
    if isinstance(name, numbers.Integral) and not isinstance(name, bool):
        name = color_names[int(name) % len(color_names)]
    if name in colors:
        return colors[name][brightness]
    raise ValueError('{} is not a valid color'.format(name))
def _to_tango_colors(elems, brightness=0):
    """
    Map each label in ``elems`` to a Tango color.

    Distinct labels are numbered in sorted order, and each number is turned
    into a color with ``_tango_color``. Equal labels therefore always share a
    color, and the assignment does not depend on set iteration order. That
    order can differ with insertion order and, for strings, between
    interpreter runs because of hash randomization. Labels that cannot be
    ordered fall back to first-appearance order.

    Parameters
    ----------
    elems : iterable
        Hashable labels, e.g. class targets or predictions.
    brightness : int, optional
        Shade index forwarded to ``_tango_color``.

    Returns
    -------
    list of str
        One hex color per element of ``elems``, in the same order.
    """
    # Materialize first: the labels are traversed twice below, which would
    # silently yield nothing for a one-shot iterator.
    elems = list(elems)
    try:
        labels = sorted(set(elems))
    except TypeError:
        # Mixed/unorderable label types: keep a deterministic order anyway.
        labels = list(dict.fromkeys(elems))
    index_of = {label: idx for idx, label in enumerate(labels)}
    return [_tango_color(index_of[e], brightness) for e in elems]