"""
.. inheritance-diagram:: pyopus.evaluator.distrib
    :parts: 1

Probability distributions of random variables, each defined through 
transformations to and from a random variable with the standard normal 
distribution N(0,1) and one with the standard uniform distribution on (0,1). 
"""

from scipy.special import erf, erfinv
from .. import PyOpusError
from .transform import *

__all__ = [
	"Distribution", 
	"Normal", 
	"Uniform", 
	"UnifAbsTol", 
	"UnifRelTol", 
	"distList", 
]

		
class Distribution(TransformBase):
	"""
	Base class for random variable distributions. 
	
	A concrete distribution is described by four transformations between 
	its own random variable and 
	
	* a standard normal variable N(0,1) (:meth:`fromSND`, :meth:`toSND`), 
	* a standard uniform variable on (0,1) (:meth:`fromSUD`, :meth:`toSUD`). 
	
	The placeholders defined here do nothing and return ``None``; 
	subclasses are expected to override them. 
	"""
	
	def __init__(self):
		# The base class keeps no state of its own; only TransformBase 
		# needs to be initialised. 
		TransformBase.__init__(self)
	
	def fromSND(self, x):
		"""
		Transform *x* from the standard normal distribution N(0,1) to 
		this distribution. Placeholder, returns ``None``. 
		"""
		return None
	
	def toSND(self, x):
		"""
		Transform *x* from this distribution to the standard normal 
		distribution N(0,1). Placeholder, returns ``None``. 
		"""
		return None
	
	def fromSUD(self, x):
		"""
		Transform *x* from the uniform distribution on (0,1) to this 
		distribution. Placeholder, returns ``None``. 
		"""
		return None
	
	def toSUD(self, x):
		"""
		Transform *x* from this distribution to the uniform distribution 
		on (0,1). Placeholder, returns ``None``. 
		"""
		return None
	
	def CDF(self, x):
		"""
		Cumulative distribution function evaluated at *x*. 
		
		The transformation to the uniform distribution on (0,1) is exactly 
		the CDF, so this delegates to :meth:`toSUD`. 
		"""
		probability=self.toSUD(x)
		return probability
	
	def mean(self):
		"""
		Returns the mean of the distribution. 
		
		Computed as the image of 0 (the centre of N(0,1)) under 
		:meth:`fromSND`. NOTE(review): for a monotone transformation this 
		is the median; it coincides with the mean for symmetric 
		distributions such as :class:`Normal` and :class:`Uniform`. 
		"""
		centre=0
		return self.fromSND(centre)
	
	def bounds(self):
		"""
		Returns the lower and the upper bound of the distribution. 
		A bound that is not finite is reported as ``None``; the base 
		class reports both bounds as not finite. 
		"""
		lower=None
		upper=None
		return lower, upper
		
	

class Normal(Distribution):
	"""
	Normal distribution with given *mean* and standard deviation *sigma*. 
	
	*sigma* must not be negative (:class:`PyOpusError` is raised 
	otherwise). ``sigma=0`` is accepted, but then :meth:`toSND` and 
	:meth:`toSUD` divide by zero. 
	"""
	
	def __init__(self, mean, sigma):
		# Initialise the base classes the same way Distribution.__init__ 
		# does, so that TransformBase is not skipped. 
		Distribution.__init__(self)
		# Validate before storing anything on the instance. 
		if sigma<0:
			raise PyOpusError("Standard deviation of normal distribution must not be negative.")
		self.m=mean
		self.s=sigma
		
	def fromSND(self, x):
		"""
		Convert from standard normal distribution to N(mean, sigma). 
		"""
		return x*self.s+self.m
	
	def toSND(self, x):
		"""
		Convert from N(mean, sigma) to standard normal distribution. 
		"""
		return (x-self.m)/self.s
	
	def fromSUD(self, x):
		"""
		Convert from uniform distribution over (0,1) to N(mean, sigma). 
		
		Uses the inverse normal CDF, 
		mean + sigma*sqrt(2)*erfinv(2*x-1). 
		"""
		return 2**0.5*self.s*erfinv(2*x-1.0)+self.m
		
	def toSUD(self, x):
		"""
		Convert from N(mean, sigma) to uniform distribution over (0,1). 
		
		This is the normal CDF, 0.5*(1+erf((x-mean)/(sigma*sqrt(2)))). 
		"""
		return 0.5*(1+erf((x-self.m)/(self.s*2**0.5)))
	

class Uniform(Distribution):
	"""
	Uniform distribution over (a,b). 
	
	The lower bound *a* must be strictly smaller than the upper bound *b* 
	(:class:`PyOpusError` is raised otherwise). 
	"""
	
	def __init__(self, a=0, b=1):
		# Initialise the base classes the same way Distribution.__init__ 
		# does, so that TransformBase is not skipped. 
		Distribution.__init__(self)
		# Validate before storing anything on the instance. 
		if a>=b:
			raise PyOpusError("Lower bound of uniform distribution must be smaller than upper bound.")
		self.a=a
		self.b=b 
	
	def fromSUD(self, x):
		"""
		Convert from uniform distribution over (0,1) to uniform 
		distribution over (a,b). 
		"""
		return x*(self.b-self.a)+self.a
	
	def toSUD(self, x):
		"""
		Convert from uniform distribution over (a,b) to uniform 
		distribution over (0,1). 
		"""
		return (x-self.a)/(self.b-self.a)
		
	def fromSND(self, x):
		"""
		Convert from standard normal distribution to uniform 
		distribution over (a,b). 
		
		The normal CDF maps N(0,1) onto (0,1), which is then stretched 
		onto (a,b) by :meth:`fromSUD`. 
		"""
		return self.fromSUD(0.5*(1+erf(x/2**0.5)))
	
	def toSND(self, x):
		"""
		Convert from uniform distribution over (a,b) to standard 
		normal distribution. 
		
		:meth:`toSUD` maps (a,b) onto (0,1), which the inverse normal 
		CDF then maps onto N(0,1). 
		"""
		return 2**0.5*erfinv(2*self.toSUD(x)-1.0)
	
	def bounds(self):
		"""
		Returns the lower and the upper bound of the distribution. 
		"""
		# The bounds are instance attributes; the bare names a and b 
		# are not defined in this scope. 
		return self.a, self.b


class UnifAbsTol(Uniform):
	"""
	Uniform distribution over (center-tol, center+tol). 
	
	*tol* is an absolute tolerance and must be greater than zero 
	(:class:`PyOpusError` is raised otherwise). 
	"""
	
	def __init__(self, center, tol):
		self.center, self.tol = center, tol
		if tol<=0:
			raise PyOpusError("Tolerance must be greater than zero.")
		# The interval is symmetric around the center. 
		lower=center-tol
		upper=center+tol
		Uniform.__init__(self, lower, upper)


class UnifRelTol(Uniform):
	"""
	Uniform distribution between center*(1-tol) and center*(1+tol), 
	i.e. over center +/- abs(center)*tol. 
	
	*tol* is a relative tolerance and must be greater than zero. A 
	negative *center* is supported: the two bounds are ordered so that 
	the lower one is passed to :class:`Uniform` first. A zero *center* 
	would give an empty interval and is rejected. :class:`PyOpusError` 
	is raised for invalid arguments. 
	"""
	
	def __init__(self, center, tol):
		self.center=center
		self.tol=tol 
		if tol<=0:
			raise PyOpusError("Tolerance must be greater than zero.")
		if center==0:
			raise PyOpusError("Center of a relative tolerance distribution must not be zero.")
		lower=center*(1-tol)
		upper=center*(1+tol)
		# For a negative center the products come out in reverse order; 
		# order them so Uniform receives a valid (lower, upper) pair. 
		Uniform.__init__(self, min(lower, upper), max(lower, upper))

# All distribution classes provided by this module. 
distList=[Normal, Uniform, UnifAbsTol, UnifRelTol]

if __name__=='__main__':
	# Round-trip self-test: mapping a random sample to N(0,1) or U(0,1) 
	# and back should reproduce it, so both printed differences are ~0. 
	from numpy.random import uniform
	cases=(
		(Normal(5,10), (-20, 20)), 
		(Uniform(5,10), (6.0, 9.0)), 
	)
	for dist, (lo, hi) in cases:
		for _ in range(10):
			sample=uniform(lo, hi)
			viaSND=dist.fromSND(dist.toSND(sample))
			viaSUD=dist.fromSUD(dist.toSUD(sample))
			print(sample-viaSND, sample-viaSUD)
	
