# Copyright (c) 2022 Panagiotis Anagnostou
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Interactive visualization module for the member algorithms of the HiPart
package that use a decomposition method to project the data onto one
dimension in order to split them.
@author: Panagiotis Anagnostou
"""
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
from KDEpy import FFTKDE
from HiPart.clustering import DePDDP
from HiPart.clustering import IPDDP
from HiPart.clustering import KMPDDP
from HiPart.clustering import PDDP
from tempfile import NamedTemporaryFile
import dash
import HiPart
import json
import matplotlib
import numpy as np
import os
import pandas as pd
import pickle
import plotly.subplots as subplots
import plotly.express as px
import plotly.graph_objects as go
import signal
import socket
import statsmodels.api as sm
import webbrowser
# Stylesheet located in the "assets" folder next to the installed HiPart package
external_stylesheets = [os.path.dirname(HiPart.__file__) + "/assets/int_viz.css"]
# Single module-level Dash application; the callback decorator and main() in
# this module attach to this instance
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
# Module-level placeholder. main() assigns a local variable of the same name
# (no `global` statement), so this value is not updated by it.
port = -1
# %% Entry point
def main(inputData):
    """
    The main function of the interactive visualization.

    The input object is pickled into temporary files, the layout is attached
    to the module-level Dash app, a browser tab is opened on the first free
    port at or above 8050 and the server is started. Once
    ``app.run_server`` returns, the (possibly manipulated) object is read
    back and the temporary files are removed. The removal also happens when
    any of the steps raises.

    Parameters
    ----------
    inputData : dePDDP or iPDDP or kM_PDDP or PDDP object
        The object to be visualized.

    Raises
    ------
    TypeError
        If inputData is not an instance of DePDDP, IPDDP, KMPDDP or PDDP.
    ValueError
        If ``inputData.visualization_utility`` is falsy.

    Returns
    -------
    obj : dePDDP or iPDDP or kM_PDDP or PDDP object
        The manipulated object after the execution of the interactive
        visualization. (Currently not working correctly)

    """
    # Compatibility check of the input data with the interactive visualization
    if not isinstance(inputData, (DePDDP, IPDDP, KMPDDP, PDDP)):
        raise TypeError(
            """inputData : can only be instances of classes that belong to
PDDP based algorithm ('dePDDP', 'iPDDP', 'kM-PDDP, PDDP') included
in HiPart package."""
        )

    if not inputData.visualization_utility:
        raise ValueError(
            """The visualization of the data cannot be achieved because the
visualization_utility is False."""
        )

    tmpFileNames = {}
    try:
        # Create the temp files. Each handle is closed right away so that the
        # path can be reopened by name later on; delete=False keeps the file
        # on disk until the explicit cleanup below.
        for key, suffix in (
            ("input_object", ".inv"),
            ("new_input_object", ".inv"),
            ("figure_path", ".png"),
        ):
            with NamedTemporaryFile("wb+", suffix=suffix, delete=False) as tmp_file:
                tmpFileNames[key] = tmp_file.name

        # Load the necessary data on the temp files
        for key in ("input_object", "new_input_object"):
            with open(tmpFileNames[key], "wb") as dump_file:
                pickle.dump(inputData, dump_file)

        # Create and run the server app
        app_layout(app=app, tmpFileNames=tmpFileNames)
        port = _next_free_port(8050)
        webbrowser.open("http://localhost:" + str(port), new=2)
        app.run_server(debug=False, port=port)

        # Recall the object to return before temp file deletion
        with open(tmpFileNames["new_input_object"], "rb") as obj_file:
            obj = pickle.load(obj_file)
    finally:
        # Delete every temp file that was created, whatever happened above
        for path in tmpFileNames.values():
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

    return obj
def _next_free_port(port, max_port=65535):
    """
    Find the first port to use, starting from the port given as input.

    Every candidate is probed by binding a fresh TCP socket to it on all
    interfaces. The probing socket is always closed, both when the bind
    succeeds and when it fails.

    Parameters
    ----------
    port : int
        The first port to check.
    max_port : int, optional
        The maximum numbered port to check. The default is 65535.

    Raises
    ------
    IOError
        For no available ports to use.

    Returns
    -------
    port : int
        The number of the first available port.

    """
    while port <= max_port:
        # The context manager closes the probe socket on success and failure
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("", port))
            except OSError:
                port += 1
            else:
                # The port is released again here; the caller is expected to
                # bind it shortly afterwards.
                return port
    raise IOError("no free ports")
# %% Callbacks
@app.callback(
    Output("splitpoint_Main", "children"),
    Output("splitView", "max"),
    Output("splitView", "marks"),
    Output("splitpoint_Manipulation", "value"),
    Output("splitpoint_Manipulation", "min"),
    Output("splitpoint_Manipulation", "max"),
    Output("splitpoint_Manipulation", "marks"),
    State("cache_dump", "data"),
    State("splitpoint_Scatter", "figure"),
    Input("splitView", "value"),
    State("splitView", "max"),
    State("splitView", "marks"),
    Input("splitpoint_Manipulation", "value"),
    State("splitpoint_Manipulation", "min"),
    State("splitpoint_Manipulation", "max"),
    State("splitpoint_Manipulation", "marks"),
    Input("splitpoint_Man_apply", "n_clicks"),
    prevent_initial_call=True,
)
def Splitpoint_Manipulation_Callback(
    data,
    current_figure,
    split_number,
    maximum_number_splits,
    split_marks,
    splitpoint_position,
    splitpoint_minimum,
    splitpoint_max,
    splitpoint_marks,
    apply_button,
):
    """
    Function triggered by the change in value on the Input elements. The
    triggers are:

    1. the change in the split of the data.
    2. the changes in the split-point.
    3. the apply button for the change to the new split-point (partial
       algorithm execution from the selected split).

    The rest of the function's inputs are inputs that can't trigger this
    callback but their data are necessary for the execution of this callback.

    Parameters
    ----------
    data : str
        JSON encoded dictionary with the paths of the temporary files created
        for the execution of the interactive visualization.
    current_figure : dict
        A dictionary created from the plotly express object plots. It is used
        for the manipulation of the current figure.
    split_number : int
        The value of the split to project (or is projected, depending on the
        callback context triggered state).
    maximum_number_splits : int
        The number of splits there are in the model created/manipulated.
    split_marks : dict
        The split number is the key of the dictionary and the assigned value
        is the value of the dictionary.
    splitpoint_position : float
        The current position of the shape represents the split-point of the
        currently selected split, extracted from the split-point slider.
    splitpoint_minimum : float
        The minimum value that can be assigned to the split-point, extracted
        from the split-point slider.
    splitpoint_max : float
        The maximum value that can be assigned to the split-point, extracted
        from the split-point slider.
    splitpoint_marks : dict
        The marks currently shown on the split-point slider.
    apply_button : int
        The number of clicks of the apply button (Not needed in the function`s
        execution, but necessary for the callback definition).

    Returns
    -------
    figure : dash.dcc.Graph
        A figure that can be integrated at dash`s Html components.
    splitMax : float
        The new value for the maximum split number.
    splitMarks : dict
        The changed marks the split can utilize as values.
    splitpoint : float
        The newly generated split-point by the callback.
    splitPMin : float
        The minimum value the split-point can take.
    splitPMax : float
        The maximum value the split-point can take.
    splitPMarks : dict
        The marks of the split-point slider.

    """
    callback_context = dash.callback_context
    callback_ID = callback_context.triggered[0]["prop_id"].split(".")[0]

    # decode the cached dictionary of temp file paths
    data = json.loads(data)
    object_path = data["new_input_object"]

    # Basic check on the callback trigger, based on the dash html elements ID.
    if callback_ID == "splitView":
        # another split was selected: redraw it from the stored model
        (
            current_figure,
            splitpoint,
            splitPMin,
            splitPMax,
            _,
        ) = _split_view_from_model(object_path, split_number)
        splitMax = maximum_number_splits
        splitMarks = split_marks
        splitPMarks = {
            splitpoint: {
                "label": "Generated Split-point",
                "style": {"color": "#77b0b1"},
            }
        }

    elif callback_ID == "splitpoint_Man_apply":
        # execution of the split-point manipulation algorithmically; the
        # result overwrites the stored model
        with open(object_path, "rb") as cls_file:
            cls = pickle.load(cls_file)
        cls = _recalculate_after_spchange(cls, split_number, splitpoint_position)
        with open(object_path, "wb") as cls_file:
            pickle.dump(cls, cls_file)

        # reconstruction of the figure and its sliders from scratch
        (
            current_figure,
            splitpoint,
            splitPMin,
            splitPMax,
            internal_nodes,
        ) = _split_view_from_model(object_path, split_number)
        splitMax = len(internal_nodes) - 1
        splitMarks = {str(i): str(i) for i in range(len(internal_nodes))}
        splitPMarks = {
            splitpoint: {
                "label": "Generated Split-point",
                "style": {"color": "#77b0b1"},
            }
        }

    else:
        # reconstruct the already created figure as figure from dict
        current_figure = go.Figure(
            data=current_figure["data"],
            layout=go.Layout(current_figure["layout"]),
        )
        if callback_ID == "splitpoint_Manipulation":
            # the slider moved: update the split-point shape location
            current_figure.update_shapes(
                {"x0": splitpoint_position, "x1": splitpoint_position}
            )

        # the sliders keep the values they already have
        splitMax = maximum_number_splits
        splitMarks = split_marks
        splitpoint = splitpoint_position
        splitPMin = splitpoint_minimum
        splitPMax = splitpoint_max
        splitPMarks = splitpoint_marks

    return (
        dcc.Graph(
            id="splitpoint_Scatter",
            figure=current_figure,
            config={"displayModeBar": False},
        ),
        splitMax,
        splitMarks,
        splitpoint,
        splitPMin,
        splitPMax,
        splitPMarks,
    )


def _split_view_from_model(object_path, split_number):
    """
    Build the figure of one split of the pickled model stored at object_path.

    Returns the tuple (figure, splitpoint, minimum of "PC1", maximum of
    "PC1", internal_nodes).
    """
    data_matrix, splitpoint, internal_nodes, number_of_nodes = _data_preparation(
        object_path, split_number
    )

    # Legend order: the cluster labels that occur in this split, sorted by
    # their numeric value. Labels are taken from the data so that the order
    # also applies to splits that do not contain cluster "0".
    category_order = {
        "cluster": [
            str(label)
            for label in sorted(np.unique(data_matrix["cluster"]), key=int)
        ]
    }

    # one color per tree node, keyed by the color key as a string
    color_map = matplotlib.cm.get_cmap("tab20", number_of_nodes)
    color_list = {str(i): _convert_to_hex(color_map(i)) for i in range(color_map.N)}

    with open(object_path, "rb") as obj_file:
        obj = pickle.load(obj_file)

    # DePDDP gets a density marginal plot, every other model a scatter one
    if isinstance(obj, DePDDP):
        figure = _int_make_scatter_n_hist(
            data_matrix,
            splitpoint,
            obj.bandwidth_scale,
            category_order,
            color_list,
        )
    else:
        figure = _int_make_scatter_n_marginal_scatter(
            data_matrix, splitpoint, category_order, color_list
        )

    return (
        figure,
        splitpoint,
        data_matrix["PC1"].min(),
        data_matrix["PC1"].max(),
        internal_nodes,
    )
# %% Split-point Manipulation page
def _Splitpoint_Manipulation(object_path):
    """
    Build the split-point manipulation page for the first split (serial
    number 0) of the pickled object: description, split selection slider,
    scatter figure with the split-point line, split-point slider and the
    apply button.

    Parameters
    ----------
    object_path : str
        The location of the temporary file containing the pickle dump of the
        object we want to visualize.

    Returns
    -------
    dash.html.Div
        The container holding the components listed above, in that order.

    """
    # data of the first split
    projection, generated_sp, splits, node_count = _data_preparation(object_path, 0)
    pc1 = projection["PC1"]

    # legend order: the cluster labels as returned by np.unique
    legend_order = {"cluster": list(map(str, np.unique(projection["cluster"])))}

    # one hexadecimal color per color key
    palette = matplotlib.cm.get_cmap("tab20", node_count)
    cluster_colors = {}
    for idx in range(palette.N):
        cluster_colors[str(idx)] = _convert_to_hex(palette(idx))

    with open(object_path, "rb") as handle:
        model = pickle.load(handle)

    # density marginal plot for DePDDP, marginal scatter plot otherwise
    if not isinstance(model, DePDDP):
        figure = _int_make_scatter_n_marginal_scatter(
            projection,
            generated_sp,
            legend_order,
            cluster_colors,
        )
    else:
        figure = _int_make_scatter_n_hist(
            projection,
            generated_sp,
            model.bandwidth_scale,
            legend_order,
            cluster_colors,
        )

    # Markdown text shown above the figure
    page_text = dcc.Markdown(
        _message_center("des:splitpoitn_man", object_path),
        style={
            "text-align": "left",
            "margin": "-20px 0px 0px 0px",
        },
    )

    # slider selecting which split is shown
    split_labels = [str(number) for number in range(len(splits))]
    split_selector = dcc.Slider(
        id="splitView",
        min=0,
        max=len(splits) - 1,
        value=0,
        marks=dict(zip(split_labels, split_labels)),
        step=1,
    )
    split_panel = html.Div(
        ["Select split: ", split_selector],
        style={
            "text-align": "left",
            "width": "740px",
            "padding": "20px 80px 20px 80px",
        },
    )

    # container of the main figure
    graph = dcc.Graph(
        id="splitpoint_Scatter",
        figure=figure,
        config={"displayModeBar": False, "autosizable": True},
    )
    figure_panel = html.Div(
        id="splitpoint_Main",
        children=[graph],
        style={"margin": "-20px 0px -30px 0px"},
    )

    # slider moving the split-point, padded by one step on both sides
    splitpoint_selector = dcc.Slider(
        id="splitpoint_Manipulation",
        min=pc1.min() - 0.001,
        max=pc1.max() + 0.001,
        marks={
            generated_sp: {
                "label": "Generated Splitpoint",
                "style": {"color": "#77b0b1"},
            }
        },
        step=0.001,
        value=generated_sp,
        included=False,
    )
    splitpoint_panel = html.Div(
        ["Select splitpoint: ", splitpoint_selector],
        style={
            "text-align": "left",
            "width": "627px",
            "margin": "0px",
            "padding": "25px 0px 0px 112px",
        },
    )

    # button applying the manipulated split-point
    apply_control = html.Button("Apply", id="splitpoint_Man_apply", n_clicks=0)

    page = [page_text, split_panel, figure_panel, splitpoint_panel, apply_control]
    return html.Div(page)
# %% Utilities of the interactive visualization
def _data_preparation(object_path, splitVal):
    """
    Generate the necessary data for all the visualizations.

    Parameters
    ----------
    object_path : str
        The location of the temporary file containing the pickle dump of the
        object we want to visualize. The file is unpickled, so it must come
        from a trusted source.
    splitVal : int
        The serial number of split that want to extract data from. It indexes
        the list of internal nodes; an out of range value raises IndexError.

    Returns
    -------
    data_matrix : pandas.core.frame.DataFrame
        Columns "PC1" and "PC2" hold the first two columns of the selected
        node's "projection"; column "cluster" holds, as a string, the color
        key of the leaf each sample belongs to.
    splitpoint : int
        The "splitpoint" entry of the selected node's data.
    internal_nodes : list
        The identifiers of the nodes that are not leaves, in ascending order,
        or [0] when the tree has no such node.
    number_of_nodes : int
        The number of nodes in the tree.

    """
    # load the data from the temp files
    with open(object_path, "rb") as obj_file:
        tree = pickle.load(obj_file).tree

    # Find the clusters from the algorithms tree
    clusters = sorted(tree.leaves(), key=lambda leaf: leaf.identifier)
    root = tree.get_node(0)

    # match the points to their respective cluster (integer color key)
    cluster_map = np.zeros(len(root.data["indices"]), dtype=int)
    for leaf in clusters:
        cluster_map[leaf.data["indices"]] = int(leaf.data["color_key"])

    # Search for the internal nodes (splits) of the tree: every identifier
    # present in the tree that does not belong to a leaf
    node_dictionary = tree.nodes
    number_of_nodes = len(node_dictionary)
    leaf_ids = {leaf.identifier for leaf in clusters}
    internal_nodes = [
        node_id for node_id in sorted(node_dictionary) if node_id not in leaf_ids
    ]

    # When there is nothing in the tree but the root, insert the root as
    # internal node (tree theory)
    if not internal_nodes:
        internal_nodes = [0]

    # based on the splitVal imported find the node to split
    node_to_visualize = tree.get_node(internal_nodes[splitVal])

    # create a data matrix containing the 1st and 2nd principal components and
    # each clusters respective color key
    projection = node_to_visualize.data["projection"]
    data_matrix = pd.DataFrame(
        {
            "PC1": projection[:, 0],
            "PC2": projection[:, 1],
            "cluster": cluster_map[node_to_visualize.data["indices"]],
        }
    )
    data_matrix["cluster"] = data_matrix["cluster"].astype(str)

    # determine the split-point value
    splitpoint = node_to_visualize.data["splitpoint"]

    return data_matrix, splitpoint, internal_nodes, number_of_nodes
def _convert_to_hex(rgba_color):
    """
    Translate an RGBa color into the "#rrggbb" notation used by the CSS.

    Parameters
    ----------
    rgba_color : tuple
        A tuple of floats containing the RGBa values. Only the first three
        entries are read; each is multiplied by 255 and truncated to an
        integer.

    Returns
    -------
    str
        The hexadecimal value of the color in question.

    """
    channels = {}
    for position, key in enumerate("rgb"):
        # truncation (not rounding) towards zero, e.g. 127.5 becomes 127
        channels[key] = int(rgba_color[position] * 255)
    return "#{r:02x}{g:02x}{b:02x}".format(**channels)
def _message_center(message_id, object_path):
    """
    This function is responsible for saving and returning all the messages seen
    in the interactive visualization.

    Parameters
    ----------
    message_id : str
        The id of the message to be returned. The known ids are
        "des:main_cluser" and "des:splitpoitn_man".
    object_path : str
        The absolute path of the pickle dump of the object the interactive
        visualization visualizes. The file is unpickled, so it must come from
        a trusted source.

    Returns
    -------
    message : str
        The requested message, or an empty string for an unknown id.

    """
    with open(object_path, "rb") as obj_file:
        obj = pickle.load(obj_file)

    # the class name of the visualized object is inserted in the messages
    obj_name = type(obj).__name__

    msg = ""
    if message_id == "des:main_cluser":
        msg = (
            """
### """
            + obj_name
            + """: Basic visualization
This is the basic visualization for the clustering of the input data.
This figure is generated by visualizing the first two components of the
decomposition method used by the """
            + obj_name
            + """ algorithm. The
colors on visualization represent each of the separate clusters
extracted by the execution of the """
            + obj_name
            + """ algorithm. It is
important to notice that each color represents the same cluster
throughout the execution of the interactive visualization (the colors
on the clusters only change with the manipulation of the execution of
the algorithm).
"""
        )

    elif message_id == "des:splitpoitn_man":
        msg = (
            """
### """
            + obj_name
            + """: Splitpoint Manipulation
On this page, you can manipulate the split point of the
"""
            + obj_name
            + """ algorithm. The process is similar to all the
algorithms, members of the HiPart package.
For the split-point manipulation, the top section of the figure is the
selection of split to manipulate. This can be done with the help of the
top sliding bar. The numbers below the bar each time represent the
serial number of the split to manipulate. Zero is the first split of
the data set. It is necessary to notice that the manipulation of the
split point because of the nature of the execution must start from the
earliest to the latest split otherwise, the manipulation will be lost.
The middle section of the figure visualizes the data with the
utilization of the decomposition technique used to execute the
"""
            + obj_name
            + """ algorithm. The colors on the scatter plot
represent the final clusters of the input dataset. The red vertical
line represents the split-point of the data for the current split of
the algorithm. """
        )

        # The class is imported as DePDDP, so a comparison of the class name
        # with "dePDDP" never matched. The type check below is the same test
        # the figure creation uses to choose the density marginal plot.
        if isinstance(obj, DePDDP):
            msg += """The marginal plot for the x-axis of this visualization
represents the density of the data. The reason behind that is to
visualize the information the algorithm has to split the data.
"""
        else:
            msg += """The marginal plot for the x-axis of this visualization is
one dimension scatter plot of the split, the data of which we
visualize. The reason behind that is to visualize the information
the algorithm has to split the data.
"""

        msg += """
Finally, the bottom section of the figure allows the manipulation of
the split-point with the utilization of a continuous sliding bar. By
changing positions on the node of the sliding-bar we can see
corresponding movements on the vertical red line that represents the
split-point of the split.
The apply button, at the bottom of the page, applies the manipulated
split-point and re-executes the algorithm for the rest of the splits
that appeared after the currently selected one.
"""

    return msg
# %% App construction page
def app_layout(app, tmpFileNames):
    """
    Assign the basic interface of the interactive application to the
    ``layout`` attribute of the given app: the page head with the navigation
    links, the empty panel the figures are placed in and the store with the
    temporary file paths.

    Parameters
    ----------
    app : dash.Dash
        The application we want to create the layout on.
    tmpFileNames : dict
        A dictionary with the names (paths) of the temporary files needed for
        the execution of the algorithm. It is stored JSON encoded in the
        "cache_dump" component.

    """

    def _nav_entry(label, target):
        # one inline list item holding a link to one page
        return html.Li(
            dcc.Link(label, href=target),
            style={"display": "inline", "margin": "0px 5px"},
        )

    # ------------------------------------------------------------------
    # Head of the interactive visualization
    location = dcc.Location(id="url", refresh=False)
    close_link = html.Div(dcc.Link("X", id="shutdown", href="/shutdown"))
    title = html.H1("HiPart: Interactive Visualisation")
    # (a "Delete Split" entry pointing to /delete_split is disabled upstream)
    navigation = html.Ul(
        children=[
            _nav_entry("Clustering results", "/clustering_results"),
            _nav_entry("Split Point Manipulation", "/splitpoint_manipulation"),
        ],
        style={"margin": "60px 0px 30px 0px"},
    )

    # ------------------------------------------------------------------
    # Main section of the interactive visualization
    figure_panel = html.Div(id="figure_panel")

    # ------------------------------------------------------------------
    # cached data container, local on the browser
    cache = dcc.Store(id="cache_dump", data=str(json.dumps(tmpFileNames)))

    page_style = {
        "min-height": "100%",
        "width": "900px",
        "text-align": "center",
        "margin": "auto",
    }
    app.layout = html.Div(
        children=[
            location,
            close_link,
            title,
            navigation,
            html.Br(),
            figure_panel,
            cache,
        ],
        style=page_style,
    )
def _shutdown():
    """
    Server shutdown function, from visual environment.

    Sends SIGINT to the process this function runs in. This takes the place
    of the request based "werkzeug.server.shutdown" hook that used to be
    called here.
    """
    own_pid = os.getpid()
    os.kill(own_pid, signal.SIGINT)
def _Cluster_Scatter_Plot(object_path):
    """
    Build the clustering results page: a Markdown description followed by a
    scatter plot of the data of the first split (serial number 0), colored
    by cluster.

    Parameters
    ----------
    object_path : str
        The location of the temporary file containing the pickle dump of the
        object we want to visualize.

    Returns
    -------
    dash.html.Div
        The container with the description and the graph.

    """
    # only the data matrix and the number of nodes are needed here
    prepared = _data_preparation(object_path, 0)
    points, node_count = prepared[0], prepared[3]

    # one hexadecimal color per color key
    palette = matplotlib.cm.get_cmap("tab20", node_count)
    cluster_colors = {}
    for idx in range(palette.N):
        cluster_colors[str(idx)] = _convert_to_hex(palette(idx))

    # legend order: "0", "1", ... up to the number of distinct clusters
    distinct_clusters = len(np.unique(points["cluster"]))
    legend_order = {"cluster": [str(idx) for idx in range(distinct_clusters)]}

    # create scatter plot
    figure = px.scatter(
        points,
        x="PC1",
        y="PC2",
        color="cluster",
        category_orders=legend_order,
        color_discrete_map=cluster_colors,
    )

    # reform visualization; both axes share the same grid styling
    figure.update_layout(width=850, height=650, plot_bgcolor="#fff")
    figure.update_traces(
        mode="markers", marker=dict(size=4), hovertemplate=None, hoverinfo="skip"
    )
    axis_style = dict(
        fixedrange=True,
        showgrid=True,
        gridwidth=1,
        gridcolor="#aaa",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="#aaa",
    )
    figure.update_xaxes(**axis_style)
    figure.update_yaxes(**axis_style)

    # Markdown description of the figure
    page_text = dcc.Markdown(
        _message_center("des:main_cluser", object_path),
        style={
            "text-align": "left",
            "margin": "-20px 0px 0px 0px",
        },
    )
    graph = dcc.Graph(figure=figure, config={"displayModeBar": False})

    return html.Div([page_text, graph])
# %% Basic plots
def _int_make_scatter_n_hist(
    data_matrix, splitPoint, bandwidth_scale, category_order, colList
):
    """
    Create a figure of two stacked plots sharing the x-axis: the density
    estimation of the data on the first principal component on top and the
    scatter plot of the data below it. A red vertical line marks the
    split-point on both.

    Parameters
    ----------
    data_matrix : pandas.core.frame.DataFrame
        The projection of the data on the first two Principal Components as
        columns "PC1" and "PC2" and the final cluster each sample belong at
        the end of the algorithm's execution as column "cluster".
    splitPoint : int
        The values of the point the data are split for this plot.
    bandwidth_scale
        Standard deviation scaler for the density approximation. Allowed values
        are in the (0,1).
    category_order : dict
        The order of which to show the clusters, contained in the
        visualization, on the legend of the plot.
    colList : dict
        A dictionary containing the color of each cluster (key) as the value
        passed to plotly's color_discrete_map.

    Returns
    -------
    fig : plotly.graph_objs._figure.Figure
        The resulted figure of the function.

    """
    pc1 = data_matrix["PC1"]

    # Gaussian kernel density of PC1 with a scaled "silverman" bandwidth
    silverman_bw = sm.nonparametric.bandwidths.select_bandwidth(
        pc1, "silverman", kernel=None
    )
    estimator = FFTKDE(kernel="gaussian", bw=(bandwidth_scale * silverman_bw))
    grid, density = estimator.fit(pc1.to_numpy()).evaluate()

    fig = subplots.make_subplots(
        rows=2,
        cols=1,
        row_heights=[0.15, 0.85],
        vertical_spacing=0.02,
        shared_yaxes=False,
        shared_xaxes=True,
    )

    def _draw_split_line(lower, upper, row):
        # red vertical line at the split-point on the plot of the given row
        fig.add_shape(
            type="line",
            yref="y",
            xref="x",
            xsizemode="scaled",
            ysizemode="scaled",
            x0=splitPoint,
            x1=splitPoint,
            y0=lower,
            y1=upper,
            line=dict(color="red", width=1.5),
            row=row,
            col=1,
        )

    # top plot: the density curve
    density_curve = go.Scatter(
        x=grid,
        y=density,
        mode="lines",
        line=dict(color="royalblue", width=1.3),
        name="PC1",
        hovertemplate=None,
        hoverinfo="skip",
        showlegend=False,
    )
    fig.add_trace(density_curve, row=1, col=1)
    _draw_split_line(-0.005, density.max() * 1.2, 1)

    # bottom plot: one scatter trace per cluster
    cluster_traces = px.scatter(
        data_matrix,
        x="PC1",
        y="PC2",
        color="cluster",
        hover_name="cluster",
        category_orders=category_order,
        color_discrete_map=colList,
    )["data"]
    for trace in cluster_traces:
        fig.add_trace(trace, row=2, col=1)
    fig.update_traces(
        mode="markers",
        marker=dict(size=4),
        hovertemplate=None,
        hoverinfo="skip",
        row=2,
        col=1,
    )
    pc2 = data_matrix["PC2"]
    _draw_split_line(pc2.min() * 1.2, pc2.max() * 1.2, 2)

    # reform visualization; both axes share the same grid styling
    fig.update_layout(
        width=850,
        height=700,
        plot_bgcolor="#fff",
        margin={"t": 20, "b": 50},
    )
    axis_style = dict(
        fixedrange=True,
        showgrid=True,
        gridwidth=1,
        gridcolor="#aaa",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="#aaa",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig
def _int_make_scatter_n_marginal_scatter(
    data_matrix, splitPoint, category_order, colList, centers=None
):
    """
    Create a figure of two stacked plots sharing the x-axis: the data on the
    first principal component only (all points drawn at y = 0) on top and
    the scatter plot of the data below it. A red vertical line marks the
    split-point on both.

    Parameters
    ----------
    data_matrix : pandas.core.frame.DataFrame
        The projection of the data on the first two Principal Components as
        columns "PC1" and "PC2" and the final cluster each sample belong at
        the end of the algorithm's execution as column "cluster".
    splitPoint : int
        The values of the point the data are split for this plot.
    category_order : dict
        The order of which to show the clusters, contained in the
        visualization, on the legend of the plot.
    colList : dict
        A dictionary containing the color of each cluster (key) as the value
        passed to plotly's color_discrete_map.
    centers : numpy.ndarray, optional
        The values of the k-means' centers for the clustering of the data
        projected on the first principal component for each split, for the
        kM-PDDP algorithm. Any number of centers is accepted. The default is
        None, which draws no centers.

    Returns
    -------
    fig : plotly.graph_objs._figure.Figure
        The resulted figure of the function.

    """
    fig = subplots.make_subplots(
        rows=2,
        cols=1,
        row_heights=[0.15, 0.85],
        vertical_spacing=0.02,
        shared_yaxes=False,
        shared_xaxes=True,
    )

    def _draw_split_line(lower, upper, row):
        # red vertical line at the split-point on the plot of the given row
        fig.add_shape(
            type="line",
            yref="y",
            xref="x",
            xsizemode="scaled",
            ysizemode="scaled",
            x0=splitPoint,
            x1=splitPoint,
            y0=lower,
            y1=upper,
            line=dict(color="red", width=1.5),
            row=row,
            col=1,
        )

    # Create the marginal scatter plot of the figure (projection on one
    # principal component)
    marginal_traces = px.scatter(
        data_matrix,
        x="PC1",
        y=np.zeros(data_matrix["PC1"].shape[0]),
        color="cluster",
        hover_name="cluster",
        category_orders=category_order,
        color_discrete_map=colList,
    )["data"]
    for trace in marginal_traces:
        fig.add_trace(trace, row=1, col=1)
    fig.update_traces(
        mode="markers",
        marker=dict(size=5),
        hovertemplate=None,
        hoverinfo="skip",
        showlegend=False,
        row=1,
        col=1,
    )

    # If there are k-Means centers add them on the marginal scatter plot
    if centers is not None:
        fig.add_trace(
            go.Scatter(
                x=centers,
                # one zero per center instead of a fixed pair of zeros
                y=np.zeros(np.size(centers)),
                mode="markers",
                marker=dict(symbol=22, color="darkblue", size=15),
                name="centers",
                hovertemplate=None,
                hoverinfo="skip",
                showlegend=False,
            ),
            row=1,
            col=1,
        )

    _draw_split_line(-2.5, 2.5, 1)

    # Create the main scatter plot of the figure
    cluster_traces = px.scatter(
        data_matrix,
        x="PC1",
        y="PC2",
        color="cluster",
        hover_name="cluster",
        category_orders=category_order,
        color_discrete_map=colList,
    )["data"]
    for trace in cluster_traces:
        fig.add_trace(trace, row=2, col=1)
    fig.update_traces(
        mode="markers",
        marker=dict(size=4),
        hovertemplate=None,
        hoverinfo="skip",
        row=2,
        col=1,
    )
    pc2 = data_matrix["PC2"]
    _draw_split_line(pc2.min() * 1.2, pc2.max() * 1.2, 2)

    # Reform visualization; both axes share the same grid styling
    fig.update_layout(
        width=850, height=700, plot_bgcolor="#fff", margin={"t": 20, "b": 50}
    )
    axis_style = dict(
        fixedrange=True,
        showgrid=True,
        gridwidth=1,
        gridcolor="#aaa",
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor="#aaa",
    )
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    return fig
# %% Algorithm re-execution tasks
def _recalculate_after_spchange(hipart_object, split, splitpoint_value):
    """
    Given the serial number of the HiPart algorithm tree`s internal nodes and a
    new split-point value recreate the results of the HiPart member algorithm
    with the indicated change.

    The tree of hipart_object is modified in place: nodes with an identifier
    greater than the manipulated node's are removed (see the loop below for
    the exceptions), the status attributes ``node_ids`` and ``cluster_color``
    are reset, the new split-point is stored on the manipulated node and the
    splitting is continued with ``_partial_predict``.

    Parameters
    ----------
    hipart_object : dePDDP or iPDDP or kM_PDDP or PDDP object
        The object that we want to manipulate on the premise of this function.
    split : int
        The serial number of the split, i.e. the position of the node to
        manipulate in the list of the tree`s internal nodes.
    splitpoint_value : float
        New split-point value.

    Returns
    -------
    hipart_object : dePDDP or iPDDP or kM_PDDP or PDDP object
        The same object that was given as input, with its tree recalculated.

    """
    tree = hipart_object.tree

    # find the cluster nodes (the leaves), ordered by identifier
    clusters = tree.leaves()
    clusters = sorted(clusters, key=lambda x: x.identifier)

    # find the tree`s internal nodes a.k.a. splits
    # NOTE(review): range(number_of_nodes) presumes the node identifiers are
    # the consecutive integers 0 .. number_of_nodes - 1 -- confirm
    dictionary_of_nodes = tree.nodes
    number_of_nodes = len(list(dictionary_of_nodes.keys()))
    leaf_node_list = [j.identifier for j in clusters]
    internal_nodes = [i for i in range(number_of_nodes) if not (i in leaf_node_list)]

    # Ensure that the root node is always an internal node, even while it is
    # the only node in the tree
    if len(internal_nodes) == 0:
        internal_nodes = [0]

    # Remove all the splits (nodes) of the tree after the one being
    # manipulated. The keys are walked from the last to the first (position 0
    # is never visited) on a snapshot of the keys, so a node that has already
    # disappeared together with a removed ancestor is skipped by the
    # `is not None` check.
    # NOTE(review): the nesting of the `else` was reconstructed from source
    # that had lost its indentation; it is paired with the `is_root()` test
    # (manipulated node is the root -> remove unconditionally; otherwise keep
    # the nodes sharing the manipulated node's parent) -- confirm against the
    # upstream file.
    node_keys = list(dictionary_of_nodes.keys())
    for i in range(len(node_keys) - 1, 0, -1):
        if node_keys[i] > internal_nodes[split]:
            if tree.get_node(node_keys[i]) is not None:
                if not tree.get_node(internal_nodes[split]).is_root():
                    if (
                        tree.parent(internal_nodes[split]).identifier
                        != tree.parent(node_keys[i]).identifier
                    ):
                        tree.remove_node(node_keys[i])
                else:
                    tree.remove_node(node_keys[i])

    # Set the split permission of every remaining leaf that has a split
    # criterion to True so the algorithm can execute correctly
    dictionary_of_nodes = tree.nodes
    for i in dictionary_of_nodes:
        if dictionary_of_nodes[i].is_leaf():
            if dictionary_of_nodes[i].data["split_criterion"] is not None:
                dictionary_of_nodes[i].data["split_permission"] = True

    # reset status variables for the code to execute
    hipart_object.node_ids = len(list(dictionary_of_nodes.keys())) - 1
    hipart_object.cluster_color = len(tree.leaves()) + 1
    # store the manipulated split-point on the selected node
    tree.get_node(internal_nodes[split]).data["splitpoint"] = splitpoint_value

    # continue the algorithm`s execution from the point left
    hipart_object.tree = tree
    hipart_object.tree = _partial_predict(hipart_object)

    return hipart_object
def _partial_predict(hipart_object):
    """
    Continue the splitting of the object's tree until one of the two stopping
    criteria is met: no node is selected for splitting, or the number of
    found clusters reaches ``hipart_object.max_clusters_number``.

    Parameters
    ----------
    hipart_object : dePDDP or iPDDP or kM_PDDP or PDDP object
        The object that we want to manipulate on the premise of this function.
        Its ``tree`` is modified in place by its ``split_function``.

    Returns
    -------
    tree : treelib.Tree
        The newly created tree after the execution of the algorithm.

    """
    tree = hipart_object.tree

    found_clusters = len(tree.leaves())
    # The first selection uses the object's own ordering flag, exactly like
    # every later selection inside the loop.
    selected_node = hipart_object.select_kid(
        tree.leaves(), hipart_object.decreasing
    )

    while (selected_node is not None) and (
        found_clusters < hipart_object.max_clusters_number
    ):  # (ST1) or (ST2)
        hipart_object.split_function(tree, selected_node)  # step (1)

        # select the next kid for split
        selected_node = hipart_object.select_kid(
            tree.leaves(), hipart_object.decreasing
        )  # step (2)
        # every executed split is counted as one additional cluster
        found_clusters = found_clusters + 1  # (ST1)

    return tree