(load "schubert.l")

; Embed MONOM into a fresh monomial with NUM-VARS variables obtained from
; ALL-ZERO-MONOMIAL: exponent k of MONOM is stored at position START + k.
(defun tensor-monomial (start num-vars monom)
  (let ((result (all-zero-monomial num-vars)))
    (loop for k from 0
          for exponent across monom
          do (setf (aref result (+ start k)) exponent))
    result))

; Apply TENSOR-MONOMIAL to every term of POLY. A term is a cons whose car is
; the coefficient (copied unchanged) and whose cdr is the monomial vector.
(defun tensor-poly (start num-vars poly)
  (loop for (coefficient . monomial) in poly
        collect (cons coefficient
                      (tensor-monomial start num-vars monomial))))

; Compute the image of BASIS in a tensor product of polynomial rings, i.e. a
; ring with NUM-VARS variables, by shifting every variable index of every
; polynomial up by START (see TENSOR-POLY).
(defun tensor-basis (start num-vars basis)
  (loop for poly in basis
        collect (tensor-poly start num-vars poly)))

; Compute a Groebner basis for the MIMO ring
; N: list of transmit dimensions
; M: list of receive dimensions
; d: message dimension (integer)
;
; The ring has 2 * (length N) * d variables, grouped in blocks of d:
; transmitter i uses the block starting at i*d and receiver j the block
; starting at ((length N) + j)*d. A Schubert Groebner basis is computed once
; per distinct dimension occurring in N or M, then shifted into every block
; that needs it with TENSOR-BASIS.
(defun mimo-groebner-basis (N M d)
  (let* ((num-users (length N))
         (num-vars (* 2 num-users d))
         (factor-bases nil)
         (basis nil))
    ; Cache: alist mapping each dimension to its Schubert Groebner basis.
    (dolist (dim (append N M))
      (unless (assoc dim factor-bases)
        (push (cons dim (schubert-groebner-basis d (- dim d))) factor-bases)))
    ; Transmitter blocks 0 .. num-users-1.
    (do ((i 0 (1+ i))
         (N N (cdr N)))
        ((>= i num-users))
      (push (tensor-basis (* i d) num-vars (cdr (assoc (car N) factor-bases)))
            basis))
    ; Receiver blocks num-users .. 2*num-users-1.
    (do ((i 0 (1+ i))
         (M M (cdr M)))
        ((>= i num-users))
      (push (tensor-basis (* (+ i num-users) d) num-vars
                          (cdr (assoc (car M) factor-bases)))
            basis))
    ; BASIS holds one freshly consed list per block (last block first). Join
    ; them with LOOP rather than (APPLY #'NCONC BASIS): APPLY is only
    ; guaranteed to accept fewer than CALL-ARGUMENTS-LIMIT arguments, and that
    ; limit may be as small as 50, so the APPLY form is not portable for many
    ; users. The resulting order is identical.
    (loop for piece in basis nconc piece)))

;;; Computing the incidence class

; Given a partition PART (a vector whose entries are at most WIDTH), return a
; vector of WIDTH parts describing the complement of PART in the
; (length PART) x WIDTH box. Entry WIDTH-1-i of the result is the number of
; entries of PART, counted from its end, that are <= i (counting stops at the
; first entry > i). For a nonincreasing PART this is the conjugate of the
; complement of PART in the box; e.g. #(2 1 0) with WIDTH 2 gives #(2 1).
(defun complement-partition (part width)
  (let* ((height (length part))
         (result (make-array (list width))))
    (loop for col from 0 below width
          do (setf (aref result (- width col 1))
                   (loop for k from 0
                         while (and (< k height)
                                    (<= (aref part (- height k 1)) col))
                         finally (return k))))
    result))

; Return a new general (element type T) vector holding the elements of V
; followed by those of W. Uses the standard CONCATENATE instead of copying
; element by element into a MAKE-ARRAY result.
(defun concat-vectors (v w)
  (concatenate 'vector v w))

; Tensor product of two polynomials P and Q, each a list of
; (coefficient . monomial-vector) terms, in the tensor product of their
; polynomial rings. Every pair of terms yields one term whose coefficient is
; the product of the two coefficients and whose monomial is P's monomial
; followed by Q's (via CONCAT-VECTORS). Like monomials are not combined here.
(defun tensor-polys (p q)
  (loop for (p-coeff . p-monom) in p
        nconc (loop for (q-coeff . q-monom) in q
                    collect (cons (* p-coeff q-coeff)
                                  (concat-vectors p-monom q-monom)))))

; List the leading parts (as returned by LEADING-PART) of all monomials that
; LIST-SCHUBERT-MONOMS yields for MAX-HEIGHT, MAX-WIDTH and DEGREE.
(defun list-partitions (max-height max-width degree)
  (loop for monom in (list-schubert-monoms max-height max-width degree)
        collect (leading-part monom)))

; Compute the incidence class in Gr(d, N) x Gr(d, M) for the first vector space
; to be contained within the dual of the second.
; N and M must be at least as big as d (this is not checked here).
;
; Every monomial of the result is a Gr(d, N)-side monomial followed by a
; Gr(d, M)-side monomial (see TENSOR-POLYS). For each split
; degree-t + degree-r = d*d, each partition PART of size degree-t (height d,
; width at most min(d, N-d)) contributes the tensor product of the polynomial
; built from PART's row of COV-T with the one built from the row of COV-R
; indexed by (COMPLEMENT-PARTITION PART d). The result is passed through
; SIMPLIFY-POLY.
(defun mimo-incidence-class (N M d)
  (let ((terms nil))
    (dotimes (degree-t (1+ (* d d)))
      (let* ((degree-r (- (* d d) degree-t))
             (monoms-t (list-schubert-monoms d (- N d) degree-t))
             (monoms-r (list-schubert-monoms d (- M d) degree-r))
             (parts-t (mapcar #'leading-part monoms-t))
             (parts-r (mapcar #'leading-part monoms-r))
             (cov-t (transpose (invert-std-unipotent (cov-matrix d (- N d)
                                                                 degree-t))))
             (cov-r (transpose (invert-std-unipotent (cov-matrix d (- M d)
                                                                degree-r)))))
        (push
          (mapcan
            #'(lambda (part)
                ; When M < 2d, skip partitions with PART[2d-M-1] < d. For a
                ; nonincreasing PART these are exactly the ones whose
                ; conjugate complement would be wider than M - d.
                (if (and (< M (* 2 d)) (< (aref part (- (* 2 d) M 1)) d))
                    nil
                  (tensor-polys
                    (coeff-vector
                      (nth (position part parts-t :test #'equalp) cov-t)
                      monoms-t)
                    (coeff-vector
                      (nth (position (complement-partition part d) parts-r
                                      :test #'equalp)
                           cov-r)
                      monoms-r))))
            (list-partitions d (min d (- N d)) degree-t))
           terms)))
    ; TERMS holds d*d+1 freshly consed lists. Join them with LOOP rather than
    ; (APPLY #'NCONC TERMS): APPLY is only guaranteed to accept fewer than
    ; CALL-ARGUMENTS-LIMIT arguments, which may be as small as 50, so the
    ; APPLY form is not portable for larger d. The resulting order is the same.
    (simplify-poly (loop for chunk in terms nconc chunk))))

; Copy MONOM into a fresh NUM-VARS-variable monomial (from ALL-ZERO-MONOMIAL),
; splitting it in half: the first half of MONOM's entries go to positions
; START1, START1+1, ... and the second half to START2, START2+1, ...
(defun tensor-split-monomial (start1 start2 num-vars monom)
  (let* ((result (all-zero-monomial num-vars))
         (half (/ (length monom) 2)))
    (dotimes (k half result)
      (setf (aref result (+ start1 k)) (aref monom k)
            (aref result (+ start2 k)) (aref monom (+ half k))))))

; Compute the image of POLY in a polynomial ring with NUM-VARS variables. In
; every monomial, the first half of the original variables is placed starting
; at index START1 and the second half starting at START2
; (see TENSOR-SPLIT-MONOMIAL). Coefficients are kept as they are.
(defun tensor-split-poly (start1 start2 num-vars poly)
  (loop for (coeff . monom) in poly
        collect (cons coeff
                      (tensor-split-monomial start1 start2 num-vars monom))))

; Compute the MIMO incidence class for the ith transmitter and jth receiver,
; placed in the full ring of 2*NUM-USERS*D variables: its transmitter half
; starts at variable I*D and its receiver half at variable (NUM-USERS + J)*D.
(defun mimo-class-factor (N M d num-users i j)
  (let ((tx-offset (* i d))
        (rx-offset (* d (+ j num-users)))
        (ring-size (* 2 num-users d)))
    (tensor-split-poly tx-offset rx-offset ring-size
                       (mimo-incidence-class (nth i N) (nth j M) d))))

; Compute the factors of the MIMO class: one MIMO-CLASS-FACTOR for every
; ordered pair (transmitter i, receiver j) with i /= j. The factors are built
; with i increasing and j decreasing. The returned list is in the reverse
; order: i decreasing, and j increasing within each i.
(defun mimo-class-factors (N M d)
  (let ((users (length N)))
    (nreverse
     (loop for i from 0 below users
           nconc (loop for j from (1- users) downto 0
                       unless (= i j)
                         collect (mimo-class-factor N M d users i j))))))

; Compute the codimension of MONOMIAL within the factors indexed by BLOCKS.
; Block b covers the D variables starting at index b*D. Inside a block, the
; k-th variable (k = 1 .. D) has weight k, so the result is the sum over the
; listed blocks of k times the corresponding exponent. Empty BLOCKS gives 0.
(defun partial-codim (blocks d monomial)
  (let ((total 0))
    (loop for blk in blocks
          for base = (* blk d)
          do (loop for k from 1 to d
                   do (incf total (* k (aref monomial (+ base k -1))))))
    total))

; Filter TREE, keeping only the monomials whose codimension in the variable
; blocks listed in BLOCKS (as computed by PARTIAL-CODIM) is at least
; MIN-CODIM; monomials BELOW that threshold are dropped. This is sound when no
; further factor will use these variables and the final answer is known to
; have codimension at least MIN-CODIM in them: a dropped term could then
; never contribute to the final answer.
; When MIN-CODIM <= 0, TREE is returned as is, without filtering. (Presumably
; this is because exponents are nonnegative, so no term would be dropped.)
; Otherwise the kept terms are inserted into a new tree made by ZERO-TREE.
(defun filter-codim (blocks d num-vars min-codim tree)
  (if (<= min-codim 0)
      tree
    (let ((new-tree (zero-tree)))
      (iterate-tree tree num-vars
        #'(lambda (coeff monomial)
            (when (>= (partial-codim blocks d monomial) min-codim)
              (insert-into-tree coeff monomial new-tree))))
      new-tree)))

; Multiply together every MIMO class factor (transmitter i, receiver j) with
; i /= j, starting from ONE-TREE and using MULTIPLY-REDUCE after each factor.
; The order is i increasing, then j decreasing for each i. No codimension
; filtering is applied (compare MIMO-PRODUCT).
(defun mimo-product-nofilter (N M d fast-reduce-fn slow-reduce-fn)
  (let* ((users (length N))
         (product (one-tree (* 2 d users))))
    (loop for i from 0 below users
          do (loop for j from (1- users) downto 0
                   unless (= i j)
                     do (setf product
                              (multiply-reduce
                               (mimo-class-factor N M d users i j)
                               product fast-reduce-fn slow-reduce-fn))))
    product))

; Multiply all MIMO class factors (transmitter i, receiver j), i /= j, in the
; same order as MIMO-PRODUCT-NOFILTER, but prune the intermediate tree with
; FILTER-CODIM. Once every factor that uses a block of variables has been
; multiplied in, that block is pushed onto FILTER-BLOCKS and MIN-CODIM is
; raised by d * (dim - d), the dimension of that block's Grassmannian.
; Terms whose codimension in the filtered blocks is below MIN-CODIM are then
; dropped.
; The progress messages "Filtering", "Multiplication step" and
; "Extra filtering" are always printed to standard output.
(defun mimo-product (N M d fast-reduce-fn slow-reduce-fn)
  (let* ((num-users (length N))
         (num-vars (* 2 d num-users))
            ; This is the minimum codimension of the cycles in the blocks of
            ; variables indexed by filter-blocks. The initial value is the
            ; negative of the dimension of the result cycle, namely
            ; d * (sum of N + sum of M) - num-users * (num-users + 1) * d^2.
         (min-codim  (* -1 d (+ (reduce #'+ N) (reduce #'+ M)
                                (* -1 num-users (1+ num-users) d))))
         (tree (one-tree num-vars))
         (filter-blocks nil)) ; blocks of variables which we can filter on
    (do ((i 0 (1+ i)))
        ((>= i num-users) tree)
      ; Transmitter blocks 0 .. i-1 are finished here, so filter on all
      ; finished blocks before multiplying in transmitter i's factors.
      (setf tree (filter-codim filter-blocks d num-vars min-codim tree))
      (when t (format t "Filtering~%"))
      (do ((j (1- num-users) (1- j)))
          ((< j 0))
        (unless (= i j)
          (setf tree (multiply-reduce (mimo-class-factor N M d num-users i j)
                                       tree fast-reduce-fn slow-reduce-fn))
          (when t (format t "Multiplication step~%"))
          (when (or (and (= j (1- num-users)) (= i (1- j)))
                    (and (= i (1- num-users)) (> j 0)))
            ; In these cases, we've just processed our last factor with this
            ; value of j. Therefore, we can threshold on that block of
            ; variables. Receiver num-users-1 is last used by transmitter
            ; num-users-2, and every other receiver by transmitter
            ; num-users-1. Receiver 0 is excluded, presumably because its
            ; last factor is the final multiplication overall.
            (push (+ num-users j) filter-blocks)
            (incf min-codim (* d (- (nth j M) d)))
            (setf tree (filter-codim filter-blocks d num-vars min-codim tree))
            (when t (format t "Extra filtering~%")))))
      ; Every factor for transmitter i has now been multiplied in, so its
      ; block can be filtered from the next iteration on.
      (push i filter-blocks)
      (incf min-codim (* d (- (nth i N) d))))))

; Compute the MIMO class. N and M are lists of the same length giving the
; dimensions of the transmit and receive spaces, and d is the dimension of
; the signal. The fast and slow reduction functions are built from
; MIMO-GROEBNER-BASIS. With FILTER true (the default) the product is computed
; by MIMO-PRODUCT, which prunes intermediate terms; otherwise it is computed
; by MIMO-PRODUCT-NOFILTER.
(defun mimo-class (N M d &optional (filter t))
  (let* ((basis (mimo-groebner-basis N M d))
         (fast-fn (make-fast-reduce-fn basis))
         (slow-fn (make-slow-reduce-fn basis)))
    (if filter
        (mimo-product N M d fast-fn slow-fn)
        (mimo-product-nofilter N M d fast-fn slow-fn))))

; Return a fresh list containing VAL repeated TIMES times, or NIL when
; TIMES <= 0. Built with MAKE-LIST instead of non-tail recursion, so long
; lists cannot exhaust the control stack. A non-integer TIMES is rounded up,
; which gives the same element count as the former recursive definition.
(defun repeat (val times)
  (if (<= times 0)
      nil
      (make-list (ceiling times) :initial-element val)))

; Compute the filtered MIMO class (as MIMO-PRODUCT does) and report, via
; TIME, how long it takes to build the fast reduction function, the slow
; reduction function, and the product. Building the Groebner basis itself is
; not timed.
(defun mimo-class-time (N M d)
  (let ((basis (mimo-groebner-basis N M d)))
    (let ((fast-fn (time (make-fast-reduce-fn basis)))
          (slow-fn (time (make-slow-reduce-fn basis))))
      (time (mimo-product N M d fast-fn slow-fn)))))
