"""
Compute various information relating to d_3 invariants of tight contact
structures on lens spaces L(p,q), defined as -p/q surgery on the unknot.

This code accompanies the paper "Contact structures and reducible surgeries",
by Tye Lidman and Steven Sivek, arXiv:1410.0303.

AUTHORS:

- Steven Sivek (2014-10-01)

EXAMPLES::

   sage: lens_fraction(26,11)
   [-3, -2, -3, -2, -2]
   sage: min_d3(8,3)
   -1/4
   sage: lens_summands(12)
   [(13, 3), (15, 4), (19, 4)]

   sage: d3_values(36,5)
   {11/18, 7/36, -1/2, 3/4}
   sage: rotation_numbers(36,5)
   [-6, 6]
   sage: possible_tb(36,5)
   [-35]
"""

def lens_fraction(p,q):
  """
  Determine the continued fraction [a_1,...,a_n] for -p/q, where
  p > q > 0 and gcd(p,q)=1. The entries satisfy a_i <= -2 for all i.

  The expansion is the "minus" form
      -p/q = a_1 - 1/(a_2 - 1/(... - 1/a_n)),
  obtained by repeatedly taking the floor and inverting the remainder.
  """
  coeffs = []
  x = ZZ(-p)/ZZ(q)
  a = floor(x)
  coeffs.append(a)
  # x = a - 1/x_next, so x_next = 1/(a - x); stop once x is an integer.
  while x != a:
    x = ZZ(1) / (a - x)
    a = floor(x)
    coeffs.append(a)
  return coeffs

def linking_matrix(p,q):
  """
  Return the linking matrix of a chain of framed unknots
  on which surgery produces L(p,q), -p/q surgery on the unknot.

  The diagonal holds the continued-fraction coefficients from
  lens_fraction(p,q); consecutive unknots in the chain link once, giving
  1 on the sub- and super-diagonals and 0 elsewhere.
  """
  coeffs = lens_fraction(p,q)
  lm = matrix.diagonal(coeffs)
  for k in range(1, lm.nrows()):
    lm[k-1,k] = lm[k,k-1] = 1
  return lm

def lminv(p,q):
  """
  Return the inverse of the linking matrix for L(p,q).

  If -p/q = [a_1,...,a_n], let h_i be the numerator (in absolute value) of
  [a_1,...,a_i], with h_0 = 1. Let t_j be the numerator of [a_j,...,a_n],
  with t_{n+1} = 1. The result is the symmetric n x n rational matrix with
  entries
      (M^{-1})_{ij} = -h_{i-1} * t_{j+1} / p    for i <= j  (1-indexed).
  """
  avals = lens_fraction(p,q)
  n = len(avals)

  def _numerator_tails(a):
    ## Return [N_0,...,N_k] where k = len(a), N_k = 1, and N_j is the
    ## numerator (in absolute value) of [a[j],...,a[k-1]], computed by the
    ## recurrence N_j = |a[j]|*N_{j+1} - N_{j+2}.
    k = len(a)
    nums = [abs(a[k-1])]*k + [1]
    for j in range(k-2,-1,-1):
      nums[j] = abs(a[j])*nums[j+1] - nums[j+2]
    return nums

  fnumtail = _numerator_tails(avals)
  ## The numerator of [a_1,...,a_i] equals that of the reversed fraction
  ## [a_i,...,a_1]. So run the tail recurrence on the reversed list and
  ## reverse the result; this gives fnumhead[i] = h_i, with fnumhead[0] = 1.
  fnumhead = _numerator_tails(avals[::-1])[::-1]

  m = matrix(QQ,n,n)
  for i in range(n):
    for j in range(i,n):
      m[i,j] = fnumhead[i] * fnumtail[j+1]
      m[j,i] = m[i,j]
  ## Build the scalar from ZZ so it is an exact rational even when p is a
  ## plain Python int (e.g. from range() in lens_summands). A bare -1/p
  ## would then be floor division (Python 2) or a float (Python 3).
  return (ZZ(-1) / ZZ(p)) * m

def min_d3(p,q):
  """
  Return the d_3 invariant of the canonical tight contact structure
  on L(p,q), which minimizes d_3 among all tight contact structures.

  Computed as (n - 2 - f(p,q)/p) / 4, where n is the length of the
  continued fraction of -p/q and f(p,q) comes from f_recurrence.
  """
  length = len(lens_fraction(p,q))
  fval = f_recurrence(p,q)
  return (length - 2 - fval/p) / 4

def f_recurrence(p,q):
  """
  Compute f(p,q) = r^T (-pM^{-1}) r, which is equal to -p*c_1(X,J)^2 for
  the canonical Stein filling of the canonical tight contact structure on
  L(p,q), via its recurrence relation.

  The continued fraction is consumed from its last entry towards its first.
  (num, den) tracks the numerator/denominator pair of the current tail,
  and f is updated from the previous tail's value at each step.
  """
  coeffs = lens_fraction(p,q)
  num, den = abs(coeffs[-1]), 1
  f = (num - 2)**2
  for a in reversed(coeffs[:-1]):
    num, den = abs(a)*num - den, num
    f = ((num - den - 1)**2 + num*f) / den
  return f

def d3_values(p,q):
  """
  Return the set of values of d_3(xi), where xi ranges over
  all tight contact structures on L(p,q).

  Let [a_1,...,a_n] = lens_fraction(p,q) and let M be the linking matrix.
  Each rotation vector r has r_i in {-|a_i+2|, -|a_i+2|+2, ..., |a_i+2|}.
  For every such r this computes (r^T M^{-1} r + n - 2)/4 and collects the
  distinct results in a Sage Set.
  """
  ## itertools.product replaces Sage's CartesianProduct, which was
  ## deprecated and later removed from Sage. The iteration order does not
  ## matter because the results go into a Set.
  from itertools import product

  minv = lminv(p,q)
  n = minv.nrows()
  rranges = [range(-abs(a+2), abs(a+2)+1, 2) for a in lens_fraction(p,q)]
  dvals = []
  for rlist in product(*rranges):
    v = vector(rlist)
    csq = v.row()*minv*v.column()   # 1x1 matrix holding r^T M^{-1} r
    dvals.append((csq[0,0] + n - 2)/4)
  return Set(dvals)

def rotation_numbers(p,q):
  """
  Return a sorted list of the rotation numbers for L(p,q).

  For each d in d3_values(p,q), if s = -p*(4*d+1) is a perfect square,
  then both sqrt(s) and -sqrt(s) are included (0 only once).
  Returns [] when no value of s is a square.
  """
  ## Use a list comprehension rather than filter(): on Python 3 (current
  ## Sage) filter() returns an iterator, so len()/reuse of it would fail.
  squares = [s for s in (-p*(4*d+1) for d in d3_values(p,q)) if is_square(s)]
  rvals = []
  for s in squares:
    r = sqrt(s)
    rvals.append(r)
    if s > 0:
      rvals.append(-r)
  rvals.sort()
  return rvals

def possible_tb(p,q):
  """
  Determine all values of maxtb which could possibly produce reducible
  Legendrian surgeries with an L(p,q) summand.

  Let k be the length of the longest run of consecutive entries in
  rotation_numbers(p,q) that differ by exactly 2. The result is the list
  [-p+1, -p+2, ..., -p+k], or [] if L(p,q) has no rotation numbers.
  """
  rvals = rotation_numbers(p,q)
  if not rvals:
    return []

  maxrun, currun = 1, 1
  for prev, cur in zip(rvals, rvals[1:]):
    if cur == prev + 2:
      currun += 1
      maxrun = max(maxrun, currun)
    else:
      currun = 1
  ## Wrap in list() so the result is a list on Python 3 too. This matches
  ## the empty-case return type and the documented output, e.g. [-35].
  return list(range(-p+1, -p+maxrun+1))

def lens_summands(t):
  """
  Find all possible lens space summands L(p,q), q>1, of
  reducible Legendrian surgeries on a knot K with maxtb(K) = -t < 0.

  Searches t < p < 2t-2 (t odd) or t < p < 2t-4 (t even), and 1 < q < p
  with gcd(p,q) = 1. Only the representative q <= q^{-1} mod p is kept.
  Pairs are returned in increasing (p, q) order.
  """
  upper = (2*t - 2) if t % 2 == 1 else (2*t - 4)  ## p < upper
  return [(p, q)
          for p in range(t+1, upper)
          for q in range(2, p)
          if gcd(p,q) == 1
          and q <= ZZ(q).inverse_mod(p)
          and -t in possible_tb(p,q)]
