""" 
.. inheritance-diagram:: pyopus.optimizer.vptree
    :parts: 1

**Vantage Point-tree (VP-tree) implementation (PyOPUS subsystem name: VPT)**

Based on VPTree from PyPI. 
"""
from ..misc.debug import DbgMsgOut, DbgMsg

import numpy as np
import bisect

__all__ = [ "VPTree" ]

class VPNode:
    """
    Node of a vantage point tree (VP-tree).

    A node stores one vantage point (``vp``) and up to two subtrees. When a
    node is built, the remaining points are split at the median of their
    distances from ``vp``: points closer than the median go to the ``left``
    (inner) subtree, the others go to the ``right`` (outer) subtree. For each
    subtree the node records the smallest and the largest distance from
    ``vp`` (the min/max belt); searches use these belts for pruning.

    Parameters
    ----------
    depth : int
        Depth of the node (the root has depth 0).
    """
    def __init__(self, depth=0):
        # Number of points stored in the subtree rooted at this node
        self.npts=0
        # Depth of this node in the tree
        self.ndepth=depth
        # Vantage point (None for an empty node)
        self.vp=None
        # Subtrees and their belt ranges
        self._clearChildren()
    
    def _clearChildren(self):
        """Drop both subtrees and reset the belt ranges to their empty values."""
        self.left=None
        self.right=None
        self.left_min = np.inf
        self.left_max = 0
        self.right_min = np.inf
        self.right_max = 0
        
    def build(self, points, dist_fn, unique=True, debug=False):
        """
        (Re)build the subtree rooted at this node from *points*.

        The first point becomes the vantage point of this node. Existing
        children and belt ranges of this node are discarded first, so the
        method can also rebuild an existing subtree.

        Parameters
        ----------
        points : sequence
            Points to store (must support ``len``, indexing and slicing).
            The first point becomes the vantage point.
        dist_fn : callable
            Function returning the distance between two points.
        unique : bool
            When ``True``, a point at zero distance from a vantage point is
            treated as a duplicate and discarded.
        debug : bool
            Not used, kept for interface compatibility.

        Returns
        -------
        (max_depth, nDup) : tuple
            Largest node depth in the built subtree and the number of
            discarded duplicates.
        """
        max_depth=self.ndepth
        nDup=0
        
        if len(points)==0:
            # Nothing to store, leave an empty node
            self.vp=None
            self.npts=0
            self._clearChildren()
            return max_depth, nDup
        
        # Nodes in processing order; a parent always precedes its children
        visited=[]
        tasks=[(self, points)]
        
        while len(tasks)>0:
            node, pts = tasks.pop()
            visited.append(node)
            
            # Update max depth
            max_depth=max(max_depth, node.ndepth)
            
            # A reused node (rebuild) may still hold a stale subtree, drop it
            node._clearChildren()
            
            # Choose VP
            vp=pts[0]
            rest=pts[1:]
            node.vp=vp
            
            if len(rest)==0:
                # Leaf
                continue
            
            # Split points set at the median distance from the VP
            distances=[dist_fn(vp, p) for p in rest]
            median=np.median(distances)
            
            left_points=[]
            right_points=[]
            left_min=np.inf
            left_max=0
            right_min=np.inf
            right_max=0
            for point, distance in zip(rest, distances):
                # Two identical points end up in the same subset. Eventually one
                # of them becomes the VP and the other one has zero distance
                # from it; that is where the duplicate is detected.
                if unique and distance==0.0:
                    nDup+=1
                    continue
                
                if distance>=median:
                    right_min=min(distance, right_min)
                    if distance>right_max:
                        right_max=distance
                        # Furthest point first, it becomes the child's VP
                        right_points.insert(0, point)
                    else:
                        right_points.append(point)
                else:
                    left_min=min(distance, left_min)
                    if distance>left_max:
                        left_max=distance
                        # Furthest point first, it becomes the child's VP
                        left_points.insert(0, point)
                    else:
                        left_points.append(point)
            
            # Store belt ranges
            node.left_min=left_min
            node.left_max=left_max
            node.right_min=right_min
            node.right_max=right_max
            
            # Create subnodes, add to tasks
            if len(left_points)>0:
                node.left=VPNode(depth=node.ndepth+1)
                tasks.append((node.left, left_points))
            
            if len(right_points)>0:
                node.right=VPNode(depth=node.ndepth+1)
                tasks.append((node.right, right_points))
        
        # Point counts are computed bottom-up (children before parents) so
        # that duplicates discarded deep in the tree are not counted above.
        for node in reversed(visited):
            node.npts=1
            if node.left is not None:
                node.npts+=node.left.npts
            if node.right is not None:
                node.npts+=node.right.npts
        
        return max_depth, nDup
    
    def append(self, point, dist_fn, treeDepth, unique=True, debug=False):
        """
        Insert *point* into the subtree rooted at this node.

        The tree is descended towards the belt closest to *point*, widening
        the belts of the visited nodes. The descent stops at a node without
        a child in the chosen direction, before a node whose capacity (for
        the given *treeDepth*) would be exceeded, or when a node at depth
        *treeDepth* is reached. The subtree at the stopping point is rebuilt
        with the new point included. If this node itself is full, the whole
        subtree rooted here is rebuilt.

        Parameters
        ----------
        point : object
            Point to insert.
        dist_fn : callable
            Function returning the distance between two points.
        treeDepth : int
            Current depth of the whole tree, used for node capacities.
        unique : bool
            Discard *point* if it duplicates a stored point.
        debug : bool
            Print progress information.

        Returns
        -------
        (depth, nDup) : tuple
            Largest node depth in the rebuilt subtree and the number of
            duplicates discarded during the rebuild.
        """
        # Compute tree capacity and get usage
        nc=2**(treeDepth-self.ndepth+1)-1
        nu=self.npts
        
        # Find node sequence from root to insertion point
        if self.vp is not None:
            # Tree not empty
            nhist=[]
            node=self
            while True:
                # Check capacity
                capacity=2**(treeDepth-node.ndepth+1)-1
                if debug:
                    print("Depth", node.ndepth, "capacity=",capacity, "new points=", node.npts+1)
                if node.npts+1>capacity:
                    # No space below this node, stop search
                    break
                
                d=dist_fn(point, node.vp)
                if d-node.left_max<=node.right_min-d:
                    # Point lies closer to the inner belt than to the outer
                    # belt, choose inner tree and widen its belt
                    if d>node.left_max:
                        node.left_max=d
                    if d<node.left_min:
                        node.left_min=d
                    
                    nhist.append(node)
                    
                    if node.left is None:
                        break
                    
                    node=node.left
                    if debug:
                        print("Going left")
                    
                else:
                    # Choose outer tree (right). Points in the middle are added
                    # to the outer tree. Widen its belt.
                    if d>node.right_max:
                        node.right_max=d
                    if d<node.right_min:
                        node.right_min=d
                    
                    nhist.append(node)
                    
                    if node.right is None:
                        break
                    
                    node=node.right
                    if debug:
                        print("Going right")
                
                # If we reach a node at tree depth while the tree is still
                # sparse (under 1/4 of capacity), do not rebuild that node
                # because it would increase the tree depth before the tree is
                # full; the parent (last entry in nhist) is rebuilt instead.
                if node.ndepth==treeDepth:
                    if nu<nc/4:
                        if debug:
                            print("Reached tree depth, stopping search")
                        break
                    else:
                        if debug:
                            print("Reached tree depth, expanding tree because it is too full *****")
                        nhist.append(node)
                        break
                
        else:
            # Empty tree
            if debug:
                print("Empty tree")
            nhist=[self]
        
        if debug:
            print("Have history with length", len(nhist))
        
        # Collect points, rebuild subtree
        if len(nhist)==0:
            # No insertion point (tree full), rebuild everything below self
            if debug:
                print("rebuild tree with height=", treeDepth-self.ndepth, "at depth=", self.ndepth)
            points=self.points()
            # New point first so that it becomes the VP
            points.insert(0, point)
            depth, nDup = self.build(points, dist_fn, unique=unique)
            if debug:
                print("reached depth", depth)
        else:
            # Rebuild from the insertion node on
            insNode=nhist[-1]
            if debug:
                print("rebuild tree with height=", treeDepth-insNode.ndepth, "at depth=", insNode.ndepth)
            points=insNode.points()
            # New point first so that it becomes the VP
            points.insert(0, point)
            depth, nDup = insNode.build(points, dist_fn, unique)
            if debug:
                print("reached depth", depth)
            # Partial rebuild, update point counts of the ancestors
            # (unless the new point was discarded as a duplicate)
            if nDup==0:
                for node in nhist:
                    if node is insNode:
                        break
                    node.npts+=1
                    
        return depth, nDup
            
    def points(self):
        """
        Return a list of all points stored in the subtree rooted at this node.
        """
        l=list()
        nodes_to_visit=[self]
        while len(nodes_to_visit)>0:
            node=nodes_to_visit.pop()
            if node.vp is not None:
                l.append(node.vp)
            if node.left is not None:
                nodes_to_visit.append(node.left)
            if node.right is not None:
                nodes_to_visit.append(node.right)
        
        return l
    
    def getNNearest(self, query, n, dist_fn):
        """
        Find the *n* stored points closest to *query*.

        Parameters
        ----------
        query : object
            Query point.
        n : int
            Number of neighbours to return (strictly positive).
        dist_fn : callable
            Function returning the distance between two points.

        Returns
        -------
        list
            Up to *n* ``(distance, point)`` tuples sorted by ascending
            distance. Points at equal distance are listed in the order in
            which the search encountered them.

        Raises
        ------
        ValueError
            If *n* is not a strictly positive integer.
        """
        if not isinstance(n, (int, np.integer)) or n < 1:
            raise ValueError(DbgMsg("VPT", 'n must be strictly positive integer'))
        
        if self.vp is None:
            # Empty tree
            return []
        
        # Candidates as (distance, sequence number, point), kept sorted. The
        # sequence number breaks distance ties so that points are never
        # compared with each other (comparing numpy arrays raises an error).
        candidates=[]
        seq=0
        
        # (node, distance from query point to min/max belt)
        # Use 0 for this distance so that this node gets examined no matter where the query point lies
        nodes_to_visit=[(self, 0)]
        
        # Pruning radius. It stays infinite until n candidates are collected,
        # otherwise farther points would be rejected while the list is not full.
        furthest_d=np.inf
        
        while len(nodes_to_visit)>0:
            node, d0 = nodes_to_visit.pop(0)
            
            # Skip this node if distance from min/max belt is greater than furthest_d
            if d0>furthest_d:
                continue
            
            # Compute distance of node's VP to query point
            d=dist_fn(query, node.vp)
            if d<furthest_d:
                bisect.insort_right(candidates, (d, seq, node.vp))
                seq+=1
                if len(candidates)>n:
                    # Drop the furthest candidate
                    candidates.pop()
                if len(candidates)==n:
                    furthest_d=candidates[-1][0]
            
            # Left (inner) tree
            if node.left is not None:
                if node.left_min <= d <= node.left_max:
                    # Query lies within the inner belt, examine first
                    nodes_to_visit.insert(0, (node.left, 0))
                elif node.left_min - furthest_d <= d <= node.left_max + furthest_d:
                    # Ball B(query, furthest_d) intersects the inner belt.
                    # Belt distance is a lower bound on the distance from
                    # the query to any point in the subtree.
                    nodes_to_visit.append((node.left, node.left_min - d if d < node.left_min else d - node.left_max))
            
            # Right (outer) tree
            if node.right is not None:
                if node.right_min <= d <= node.right_max:
                    # Query lies within the outer belt, examine first
                    nodes_to_visit.insert(0, (node.right, 0))
                elif node.right_min - furthest_d <= d <= node.right_max + furthest_d:
                    # Ball B(query, furthest_d) intersects the outer belt
                    nodes_to_visit.append((node.right, node.right_min - d if d < node.right_min else d - node.right_max))
        
        return [(dist, point) for dist, _, point in candidates]
    
    def findWithin(self, query, max_distance, dist_fn):
        """
        Find all stored points closer than *max_distance* to *query*.

        Parameters
        ----------
        query : object
            Query point.
        max_distance : float
            Search radius (points at exactly this distance are excluded).
        dist_fn : callable
            Function returning the distance between two points.

        Returns
        -------
        list
            ``(distance, point)`` tuples in no particular order.
        """
        neighbours=list()
        
        if self.vp is None:
            # Empty tree
            return neighbours
        
        # (node, distance from query point to min/max belt)
        # Use 0 for this distance so that this node gets examined no matter where the query point lies
        nodes_to_visit=[(self, 0)]
        
        while len(nodes_to_visit)>0:
            node, d0 = nodes_to_visit.pop(0)
            # Skip this node if distance from min/max belt is greater than max_distance
            if d0>max_distance:
                continue
            
            # Compute distance of node's VP to query point
            d=dist_fn(query, node.vp)
            if d<max_distance:
                neighbours.append((d, node.vp))

            # Left (inner) tree
            if node.left is not None:
                if node.left_min <= d <= node.left_max:
                    # Query lies within the inner belt
                    nodes_to_visit.insert(0, (node.left, 0))
                elif node.left_min - max_distance <= d <= node.left_max + max_distance:
                    # Ball B(query, max_distance) intersects the inner belt
                    nodes_to_visit.append(
                        (node.left, node.left_min - d if d < node.left_min else d - node.left_max))
            
            # Right (outer) tree
            if node.right is not None:
                if node.right_min <= d <= node.right_max:
                    # Query lies within the outer belt
                    nodes_to_visit.insert(0, (node.right, 0))
                elif node.right_min - max_distance <= d <= node.right_max + max_distance:
                    # Ball B(query, max_distance) intersects the outer belt
                    nodes_to_visit.append((node.right, node.right_min - d if d < node.right_min else d - node.right_max))

        return neighbours
    

class VPTree:
    """ 
    VP-Tree data structure for efficient nearest neighbour search.

    The VP-tree is a data structure for efficient nearest neighbour
    searching in a metric space. For a reasonably balanced tree of n
    points a nearest neighbour query typically takes O(log n) distance
    evaluations. Construction complexity is O(n log n).

    Points can also be added one at a time with :meth:`append`, which
    rebuilds the affected subtree (or the whole tree when it is full).

    Parameters
    ----------
    points : Sequence
        Construction points (must support ``len``, indexing and slicing).
        May be empty, in which case a blank tree is created.
    dist_fn : Callable
        Function taking point instances as arguments and returning
        the distance between them.
    unique : boolean
        When ``True`` a point with zero distance to a point already in
        the tree is treated as a duplicate and discarded.
    """

    def __init__(self, points, dist_fn, unique=True):
        self.dist_fn=dist_fn
        self.unique=unique
        # Largest node depth reached so far (root has depth 0)
        self.tdepth=0
        
        self.tree=VPNode(depth=0)
        if len(points)>0:
            self.tdepth,_=self.tree.build(points, self.dist_fn, self.unique)
    
    def append(self, point):
        """
        Add *point* to the tree.

        Returns
        -------
        (depth, nDup) : tuple
            Largest node depth in the rebuilt part of the tree and the
            number of duplicates discarded (``1`` if *point* was a
            duplicate and *unique* is ``True``).
        """
        depth, nDup = self.tree.append(point, self.dist_fn, self.tdepth, self.unique)
        if depth>self.tdepth:
            self.tdepth=depth
        
        return depth, nDup
        
    def points(self):
        """Return a list of all points stored in the tree."""
        return self.tree.points()
    
    def getNNearest(self, query, n):
        """
        Return up to *n* ``(distance, point)`` tuples closest to *query*,
        sorted by ascending distance.
        """
        return self.tree.getNNearest(query, n, self.dist_fn)
    
    def getNearest(self, query):
        """
        Return the ``(distance, point)`` tuple closest to *query* or
        ``None`` if the tree is empty.
        """
        ret=self.tree.getNNearest(query, 1, self.dist_fn)
        if len(ret)==0:
            return None
        return ret[0]
    
    def findWithin(self, query, max_distance):
        """
        Return ``(distance, point)`` tuples for all points closer than
        *max_distance* to *query* (in no particular order).
        """
        return self.tree.findWithin(query, max_distance, self.dist_fn)
    
    def depth(self):
        """Return the largest node depth reached so far (root is 0)."""
        return self.tdepth
    
    def capacity(self):
        """Return the number of points a full tree of the current depth holds."""
        return 2**(self.tdepth+1)-1
    
    def nPoints(self):
        """Return the number of points stored in the tree."""
        return self.tree.npts


class _AutoSortingList(list):
    """ 
    Simple auto-sorting list (lexically ascending). 
    
    Parameters
    ---------
    max_size : int, optional
        Max queue size.
    """

    def __init__(self, max_size=None, *args):
        super(_AutoSortingList, self).__init__(*args)
        self.max_size = max_size

    def append(self, item):
        """ Append `item` and sort.

        Parameters
        ----------
        item : Any
            Input item.
        """
        bisect.insort_right(self, item)
        if len(self)>self.max_size:
            self.pop()
        

if __name__=="__main__":
    import time
    from pprint import pprint
    import cProfile
    
    def main():
        """
        Demo and sanity check: build a VP-tree by appending random points,
        then compare tree searches with brute force searches.
        """
        # Euclidean distance
        def euclidean(p, q):
            return ((p-q)**2).sum()**0.5
        
        # Brute force search for all points closer than r to q
        def slow_get_all_in_range(points, q, r, dist_fn):
            results=list()
            for point in points:
                d=dist_fn(q, point)
                if d<r:
                    results.append((d, point))
            return results
        
        # Sort key: distance first, then vector components (lexically).
        # Sorting (distance, ndarray) tuples directly fails on distance ties
        # because numpy arrays cannot be compared as a whole.
        def sort_key(entry):
            d, p = entry
            return (d, tuple(p))
        
        # Dimension
        n=5
        
        # Point count
        N=31000
        
        # Seed
        np.random.seed(0)
        
        # Generate random points
        points=np.random.rand(N,n)
        
        # Build with appending, starting from a blank tree
        print("Build by appending")
        t0=time.perf_counter()
        tree=VPTree([], euclidean)
        tt=[]
        for ii, p in enumerate(points, start=1):
            tree.append(p)
            tt.append(time.perf_counter()-t0)
            if ii%1000==0:
                print("Inserted", ii, "tree depth=", tree.depth(), "capacity=", tree.capacity())
        dt=time.perf_counter()-t0
        print("Append build time:", dt)
        print("Depth:", tree.depth(), "Count:", tree.nPoints(), "Capacity:", tree.capacity(), "\n")
        
        # Plot cumulative insertion time if matplotlib is available
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            plt=None
        if plt is not None:
            plt.plot(tt)
            plt.xlabel("points inserted")
            plt.ylabel("cumulative time [s]")
            plt.show()
        
        # Find all points within dist of origin
        print("Find within radius")
        q=np.zeros(n)
        dist=0.3
        t0=time.perf_counter()
        found=tree.findWithin(q, dist)
        dt=time.perf_counter()-t0
        print(f"Found {len(found)} points within {dist:f} of origin")
        print("Search time:", dt, "\n")
        
        # Sort according to distance, vector components (lexically)
        found=sorted(found, key=sort_key)
        V1=np.array([ p for _, p in found ])
        
        # Brute force search for all points within dist of origin
        print("Find within radius (brute force)")
        t0=time.perf_counter()
        found=slow_get_all_in_range(points, q, dist, euclidean)
        dt=time.perf_counter()-t0
        print(f"Found {len(found)} points within {dist:f} of origin")
        print("Search time:", dt)
        
        found=sorted(found, key=sort_key)
        V2=np.array([ p for _, p in found ])
        
        # Print difference
        if V1.shape==V2.shape:
            print("Point differences between tree and brute force search")
            pprint(V1-V2)
        else:
            print("Result size mismatch between tree and brute force search:", V1.shape, V2.shape)
        print()
        
        # Find closest k points to the origin and compare with brute force
        k=10
        print(f"Find {k} nearest neighbours of origin")
        t0=time.perf_counter()
        found=tree.getNNearest(q, k)
        dt=time.perf_counter()-t0
        print(f"Found {len(found)} points closest to the origin")
        print("Search time:", dt)
        brute=sorted(((euclidean(q, p), p) for p in points), key=sort_key)[:k]
        print("Distances match brute force:", [d for d, _ in found]==[d for d, _ in brute], "\n")
    
    # Uncomment to profile
    # cProfile.run('main()')
    main()
    
