Source code for onegpy.solutions.node

import copy


class Node(object):
    """Node of a solution tree.

    # Attribute
        func_id: int, symbolic id of the function this node refers to.
            The id identifies a function, not an individual node object.
        children: list of Node object, children of this node.
    """

    def __init__(self, func_id=-1):
        """
        :param func_id: int. function id (``-1`` when omitted).
        """
        self.func_id = func_id
        # A freshly created node has no children.
        self.children = []

    @property
    def is_terminal(self):
        """bool. ``True`` when this node has no children."""
        if self.children:
            return False
        return True
class Function(object):
    """Callable that pairs an evaluation function with its arity.

    # Attribute
        n_children: int, the number of children of this function.
        f_eval: function, function of node to evaluate.
    """

    def __init__(self, n_children, f_eval=None):
        """
        :param n_children: int. the number of children of this function.
        :param f_eval: function, function of node to evaluate.
        """
        self.n_children = n_children
        self.f_eval = f_eval

    def __call__(self, x):
        """
        Evaluate ``x`` with ``f_eval``.

        When ``n_children`` is positive, ``len(x)`` is first validated
        against ``n_children`` through ``n_children_checker``.

        :param x: data to evaluate.
            NOTE(review): presumably a sequence holding one entry per
            child, since its length is compared with ``n_children``.
        :return: evaluated value
        """
        expected = self.n_children
        if expected > 0:
            n_children_checker(expected, len(x))
        evaluate = self.f_eval
        return evaluate(x)
def set_id(node, func_id):
    """
    Assign a function id to a node after validating both arguments.

    :param node: Node object. target node.
    :param func_id: int. id to store in ``node.func_id``.
    :return: None
    :raises TypeError: if ``node`` is not a Node or ``func_id`` is not an int.
    :raises ValueError: if ``func_id`` is negative.
    """
    # Validate first so an invalid id never reaches the node.
    node_checker(node)
    func_id_checker(func_id)

    node.func_id = func_id
def get_n_children(func_id, function_dict):
    """
    Look up how many children the function with id ``func_id`` takes.

    :param func_id: int. function id.
    :param function_dict: Function dict object, indexed by function id.
    :return: int. the number of children.
    """
    func_id_checker(func_id)
    return function_dict[func_id].n_children
def set_children(node, children):
    """
    Replace the children of a node after validating both arguments.

    The given list object itself is stored on the node (no copy is made).

    :param node: node object. target node.
    :param children: list of node. children to set to target node.
    :return: None
    """
    node_checker(node)
    children_checker(children)

    node.children = children
def copy_node(node, deep=False):
    """
    Copy a node, either shallowly or deeply.

    Shallow mode builds a new object of the same class with the same
    ``func_id``; when the source has children, each child is duplicated
    with ``copy.copy`` into a new list. Deep mode delegates to
    ``copy.deepcopy``.

    :param node: node object, target node to copy.
    :param deep: bool. deep copy or not
    :return: node object. copy node.
    """
    if deep:
        return copy.deepcopy(node)

    duplicate = node.__class__(node.func_id)
    if node.children:
        # Children are copied one level only (copy.copy on each child).
        duplicate.children = list(map(copy.copy, node.children))
    return duplicate
def copy_nodes_along_graph(graph):
    """
    Copy the nodes listed in ``graph`` and chain the copies together.

    Every node in ``graph`` is copied with ``copy_node``. The first copy
    becomes the new root; each later copy is stored in the previous copy's
    ``children`` at the index recorded with the previous entry. This differs
    from ``deepcopy(root)`` in that only the nodes along ``graph`` are
    processed.

    :param graph: list of ``(i, Node object)`` where ``i`` is the index of
        the next node of the graph inside ``Node object``.
        Such a graph is produced by ``get_graph_to_target``.
    :return: tuple of (index stored with the last entry, copy of the last
        node in ``graph``, copy of the first node in ``graph``).
        All three are ``None`` when ``graph`` is empty.
    """
    root = None
    parent_copy = None
    slot = None

    for next_slot, original in graph:
        duplicate = copy_node(original)
        if slot is None:
            # First entry: its copy is the root of the new chain.
            root = duplicate
        else:
            parent_copy.children[slot] = duplicate
        parent_copy, slot = duplicate, next_slot

    return slot, parent_copy, root
def get_parent_node(root, target_node):
    """
    Search the tree under ``root`` for the parent of ``target_node``.

    Nodes are matched by identity (``is``), not by equality.

    :param root: Node object, root node.
    :param target_node: Node object, target node.
    :return: tuple of (position of target_node in the parent's children,
        node object of the parent node).
    :raises ValueError: if ``target_node`` is ``root`` or cannot be found.
    """
    def search(node):
        # Depth-first search; returns (index, parent) or None.
        if node.is_terminal:
            return None
        for index, child in enumerate(node.children):
            if child is target_node:
                return index, node
            found = search(child)
            if found is not None:
                return found
        return None

    nodes_checker([root, target_node])
    if target_node is root:
        raise ValueError('There is no parent of root.')

    result = search(root)
    if result is None:
        raise ValueError('Invalid arguments: cannot find parent.')

    pos, parent = result
    return pos, parent
def get_graph_to_target(root, target_node):
    """
    Find the path from ``root`` down to ``target_node``.

    Nodes are matched by identity (``is``), not by equality.

    :param root: Node object, root node.
    :param target_node: Node object, target node.
    :return: list of ``(i, Node object)`` from ``root`` to the parent of
        ``target_node``, where ``i`` is the index, inside ``Node object``,
        of the next node on the path.
    :raises ValueError: if ``target_node`` is ``root`` or cannot be found.
    """
    path = []

    def descend(node):
        # Returns True once the target has been reached below ``node``.
        if node.is_terminal:
            return False
        for index, child in enumerate(node.children):
            # Tentatively record this step; undo it if the branch fails.
            path.append((index, node))
            if child is target_node or descend(child):
                return True
            path.pop()
        return False

    nodes_checker([root, target_node])
    if target_node is root:
        raise ValueError('There is no parent of root.')

    if not descend(root):
        raise ValueError('Invalid arguments: cannot find parent.')
    return path
def get_all_node(root):
    """
    Collect every node of the solution in depth-first pre-order.

    :param root: Node object. root node of target solution.
    :return: list of Node object. All nodes in the solution, ``root`` first.
    """
    node_checker(root)
    collected = []

    def visit(node):
        # A node is recorded before any of its descendants.
        collected.append(node)
        if node.is_terminal:
            return
        for child in node.children:
            visit(child)

    visit(root)
    return collected
def calc_node_depth(node):
    """
    Calculate the depth of the subtree hanging under the target node.

    A terminal node has depth 0; otherwise the depth is one more than the
    deepest child's depth.

    :param node: node object. target node.
    :return: int. calculated depth.
    """
    def depth_below(current):
        if current.is_terminal:
            return 0
        return 1 + max(depth_below(child) for child in current.children)

    return depth_below(node)
def get_all_terminal_nodes(root):
    """
    Collect every terminal node of the solution.

    Nodes are gathered depth-first, following the order of ``children``.

    :param root: Node object. root node of target solution.
    :return: list of Node object. All terminal nodes in the solution.
    """
    node_checker(root)
    leaves = []

    def collect(node):
        if node.is_terminal:
            leaves.append(node)
        else:
            for child in node.children:
                collect(child)

    collect(root)
    return leaves
def get_all_nonterminal_nodes(root):
    """
    Collect every non-terminal node of the solution.

    Nodes are gathered depth-first in pre-order, so a node appears before
    the non-terminal nodes below it.

    :param root: Node object. root node of target solution.
    :return: list of Node object. All non-terminal nodes in the solution.
    """
    node_checker(root)
    inner_nodes = []

    def collect(node):
        if node.is_terminal:
            return
        inner_nodes.append(node)
        for child in node.children:
            collect(child)

    collect(root)
    return inner_nodes
def get_all_terminal_points(root):
    """
    Collect every (parent, index) pair that addresses a terminal node.

    According to the original documentation this is used for crossover
    in MLPS-GP. A ``root`` that is itself terminal yields an empty list
    because it has no parent.

    :param root: Node object. root node of target solution.
    :return: list of tuple(Node, int). (parent, index of terminal node)
    """
    node_checker(root)
    terminal_points = []

    def collect(parent):
        if parent.is_terminal:
            return
        for index, child in enumerate(parent.children):
            if child.is_terminal:
                terminal_points.append((parent, index))
            else:
                collect(child)

    collect(root)
    return terminal_points
def node_equal(node_a, node_b, as_tree=False):
    """
    Function for comparing two nodes.

    :param node_a: Node object
    :param node_b: Node object
    :param as_tree: If False, compare only the two given nodes: they are
        equal when they have the same function id and are both terminal or
        both non-terminal. If True, additionally require the two trees
        rooted at the nodes to have the same structure, i.e. the same
        number of nodes and, node by node in pre-order, the same function
        id and the same number of children.
    :return: bool
    :raises TypeError: if either argument is not a Node.
    """
    def same_kind(x, y):
        # Same function id and same terminal / non-terminal type.
        return x.func_id == y.func_id and bool(x.children) == bool(y.children)

    nodes_checker([node_a, node_b])
    if not as_tree:
        return same_kind(node_a, node_b)

    nodes_a = get_all_node(node_a)
    nodes_b = get_all_node(node_b)
    # zip() stops at the shorter list, so trees of different size must be
    # rejected explicitly; otherwise a tree would equal any tree that merely
    # starts with the same pre-order sequence.
    if len(nodes_a) != len(nodes_b):
        return False
    # Comparing the child count per node makes the pre-order comparison
    # sensitive to the shape of the tree, not only to the node sequence.
    return all(
        same_kind(x, y) and len(x.children) == len(y.children)
        for x, y in zip(nodes_a, nodes_b)
    )
def node_array_equal(nodes_a, nodes_b):
    """
    Compare two sequences of nodes element by element.

    Each pair is compared with ``node_equal(..., as_tree=False)``.
    Sequences of different lengths are never equal.

    :param nodes_a: iterable of Node object.
    :param nodes_b: iterable of Node object.
    :return: bool. True when both hold the same number of nodes and every
        pair of nodes at the same position is equal.
    """
    # Materialize so the lengths can be compared even for generators.
    nodes_a = list(nodes_a)
    nodes_b = list(nodes_b)
    # zip() silently truncates to the shorter input, which would make a
    # sequence equal to any longer sequence it is a prefix of.
    if len(nodes_a) != len(nodes_b):
        return False
    return all(
        node_equal(node_a, node_b, as_tree=False)
        for node_a, node_b in zip(nodes_a, nodes_b)
    )
def func_id_checker(func_id):
    """
    Validate a function id.

    :param func_id: value to check.
    :return: None when ``func_id`` is a non-negative int.
    :raises TypeError: if ``func_id`` is not an int.
    :raises ValueError: if ``func_id`` is negative.
    """
    if not isinstance(func_id, int):
        raise TypeError(
            'Expected type: {} not {}.'.format(int, type(func_id)))
    if func_id < 0:
        raise ValueError('Expected condition: func_id >= 0.')
def node_checker(node):
    """
    Validate that the argument is a Node.

    :param node: value to check.
    :return: None when ``node`` is an instance of Node.
    :raises TypeError: if ``node`` is not an instance of Node.
    """
    if isinstance(node, Node):
        return
    raise TypeError('Expected type: {} not {}.'.format(Node, type(node)))
def nodes_checker(nodes):
    """
    Validate every element of an iterable with ``node_checker``.

    :param nodes: iterable of values to check.
    :return: None when every element passes ``node_checker``.
    :raises TypeError: on the first element that is not a Node.
    """
    for candidate in nodes:
        node_checker(candidate)
def children_checker(children):
    """
    Validate a list of children.

    :param children: value to check.
    :return: None when ``children`` is a list whose elements are all Nodes.
    :raises TypeError: if ``children`` is not a list, or if any element
        is not a Node.
    """
    is_list = isinstance(children, list)
    if not is_list:
        raise TypeError('Expected type: {}'.format(list))
    # Only a genuine list has its elements inspected.
    nodes_checker(children)
def n_children_checker(n_children, len_x):
    """
    Validate that the number of inputs matches the expected arity.

    :param n_children: int. expected number of children.
    :param len_x: int. number of inputs actually supplied.
    :return: None when both values are equal.
    :raises ValueError: if ``n_children`` differs from ``len_x``; the
        message reports both values.
    """
    if n_children != len_x:
        # The message carries one placeholder per value so that neither
        # number is dropped from the report.
        raise ValueError(
            'expect the n_children == len_x, but got n_children {} != len_x {}'
            .format(n_children, len_x))