(in-package :computed-class)

;; Set up the computed universe named COMPUTE-AS.  The slot definitions below
;; declare themselves `:computed-in compute-as', and INITIALIZE-INSTANCE on
;; SOLUTION builds their values with `(compute-as ...)' forms.
(define-computed-universe compute-as)

(defun row-element-index (group-index element-index)
  "Return the row-major cell index (0-80) of the ELEMENT-INDEXth cell, counted
from the left, of row GROUP-INDEX."
  (let ((row-start (* group-index 9)))
    (+ row-start element-index)))

(defun row-group-index (solution-index)
  "Return the row (0-8) containing the cell at row-major SOLUTION-INDEX (0-80)."
  ;; Two-argument FLOOR avoids building an intermediate ratio; VALUES drops
  ;; the remainder so, like COLUMN-GROUP-INDEX and BLOCK-GROUP-INDEX, this
  ;; returns exactly one value.
  (values (floor solution-index 9)))

(defun column-element-index (group-index element-index)
  "Return the row-major cell index (0-80) of the ELEMENT-INDEXth cell, counted
from the top, of column GROUP-INDEX."
  (let ((row-offset (* 9 element-index)))
    (+ row-offset group-index)))

(defun column-group-index (solution-index)
  "Return the column (0-8) containing the cell at row-major SOLUTION-INDEX."
  ;; MOD is defined as the remainder returned by two-argument FLOOR.
  (nth-value 1 (floor solution-index 9)))

(defun block-element-index (group-index element-index)
  "Return the row-major cell index (0-80) of the ELEMENT-INDEXth cell of 3x3
block GROUP-INDEX.  Blocks and the cells inside a block are both numbered
left-to-right, top-to-bottom."
  (multiple-value-bind (block-row block-column) (floor group-index 3)
    (multiple-value-bind (cell-row cell-column) (floor element-index 3)
      ;; A block row spans 27 cells, a block column 3; a cell row spans 9.
      (+ (* 27 block-row)
         (* 3 block-column)
         (* 9 cell-row)
         cell-column))))

(defun block-group-index (solution-index)
  "Return the 3x3 block (0-8, left-to-right, top-to-bottom) containing the cell
at row-major SOLUTION-INDEX."
  (let ((block-row (floor solution-index 27))
        (block-column (floor (mod solution-index 9) 3)))
    (+ (* 3 block-row) block-column)))

;; One cell of the 9x9 grid.
(defcclass* solution-element ()
  (;; Row-major position of the cell, 0-80 (see ROW-GROUP-INDEX and
   ;; COLUMN-GROUP-INDEX).
   (index :type integer)
   ;; Value placed in the cell, or NIL while it is empty: SOLVE looks for cells
   ;; whose value is NIL and resets cells to NIL when it backtracks.
   (selected-value :computed-in compute-as)
   ;; Candidate values for the cell.  INITIALIZE-INSTANCE on SOLUTION installs
   ;; a formula giving 1-9 minus the values selected in the cell's row,
   ;; column and block.
   (possible-values :computed-in compute-as)))

;; A 9x9 grid made of 81 SOLUTION-ELEMENTs.
(defcclass* solution ()
  (;; The 81 cells.  INITIALIZE-INSTANCE fills this slot with MAKE-ARRAY, and
   ;; it is read with AREF/ACROSS, so it is typed as a VECTOR.  It is a plain
   ;; VECTOR rather than SIMPLE-VECTOR because SOLVE stores back the result of
   ;; SORT, which is only guaranteed to be a vector.
   (elements :type vector)))

(defmethod initialize-instance :after ((solution solution) &key initial-solution &allow-other-keys)
  "Populate SOLUTION with 81 SOLUTION-ELEMENTs stored in row-major order.

INITIAL-SOLUTION, when given, is indexed as a sequence of 9 rows of 9 entries.
Each entry becomes the corresponding cell's initial SELECTED-VALUE, wrapped in
COMPUTE-AS.  Every cell's POSSIBLE-VALUES is a COMPUTE-AS formula: the values
1-9 minus the values selected by the other cells of its row, column and 3x3
block."
  ;; An :AFTER method lets the standard (and any library-supplied) primary
  ;; INITIALIZE-INSTANCE method run first.  A primary method without
  ;; CALL-NEXT-METHOD would skip it entirely.
  (let ((elements (setf (elements-of solution) (make-array 81))))
    ;; First pass: create every cell, so the second pass can refer to peers.
    (loop for i :from 0 :below 81
       for initial-element = (when initial-solution
                               (elt (elt initial-solution (row-group-index i))
                                    (column-group-index i)))
       do (setf (aref elements i)
                (make-instance 'solution-element
                               :index i
                               ;; NOTE(review): presumably REBIND makes a fresh
                               ;; binding so each COMPUTE-AS form captures this
                               ;; iteration's value, not the shared LOOP variable.
                               :selected-value (rebind (initial-element)
                                                 (compute-as initial-element)))))
    ;; Second pass: each cell's candidates depend on the selected values of
    ;; its distinct peers (same row, column or block, excluding itself).
    (loop for i :from 0 :below 81
       do (setf (possible-values-of (aref elements i))
                (let ((depends-on-elements
                       (mapcar (curry #'aref elements)
                               (remove-duplicates
                                (remove i
                                        (loop for j :from 0 :below 9
                                           collect (row-element-index (row-group-index i) j)
                                           collect (column-element-index (column-group-index i) j)
                                           collect (block-element-index (block-group-index i) j)))))))
                  (compute-as
                    (set-difference '(1 2 3 4 5 6 7 8 9)
                                    (mapcar #'selected-value-of depends-on-elements))))))))

(defmethod print-object ((solution solution) stream)
  "Print SOLUTION to STREAM as a 9x9 grid of SELECTED-VALUEs.

Each cell is followed by a space.  A wider gap separates the 3x3 blocks within
a row, and a blank line separates each band of three rows.  Cells are read
positionally from ELEMENTS, which is expected to be in INDEX order.  Returns
SOLUTION, as PRINT-OBJECT methods are required to."
  (let ((elements (elements-of solution)))
    (loop for i :from 0 :below 81 do
         (unless (zerop i)
           (cond ((zerop (mod i 9))
                  ;; Start a new row; a band boundary also gets a blank line.
                  ;; No block gap is written here, to avoid trailing spaces.
                  (terpri stream)
                  (when (zerop (mod i 27))
                    (terpri stream)))
                 ((zerop (mod i 3))
                  ;; Gap between blocks inside the same row.
                  (format stream "   "))))
         (format stream "~A " (selected-value-of (aref elements i)))))
  solution)

(defun solve (solution)
  "Fill in the empty cells of SOLUTION by backtracking search.

At each step the cells are stably re-ordered by how many POSSIBLE-VALUES they
have. The first cell whose SELECTED-VALUE is NIL is then tried with each of its
candidates in turn.  Returns SOLUTION once every cell has a selected value, or
NIL if the search exhausts all candidates.  In both cases the ELEMENTS vector
is restored to INDEX order, so positional access (e.g. PRINT-OBJECT) stays
valid."
  (labels ((continue-solving ()
             ;; SORT/STABLE-SORT may return a different vector than the one
             ;; passed in, so the result is always stored back into the slot.
             (let* ((elements (setf (elements-of solution)
                                    (stable-sort (elements-of solution) #'<
                                                 :key (compose #'length #'possible-values-of))))
                    (element (find-if-not #'selected-value-of elements)))
               (when (null element)
                 ;; Every cell has a value: the search succeeded.
                 (return-from solve solution))
               (dolist (value (possible-values-of element))
                 (setf (selected-value-of element) value)
                 (continue-solving))
               ;; No candidate led to a full grid: clear the cell and backtrack.
               (setf (selected-value-of element) nil)
               nil)))
    ;; The cleanup runs on success (non-local RETURN-FROM), on failure, and on
    ;; errors, always putting the cells back into index order.
    (unwind-protect (continue-solving)
      (setf (elements-of solution)
            (sort (elements-of solution) #'< :key #'index-of)))))