# -*- Python -*-

if __name__ == '__main__':
    # When run as a script, pass the command-line arguments to
    # petsc4py.init() before petsc4py.PETSc is imported further down.
    import sys
    import petsc4py
    petsc4py.init(sys.argv)
    # Drop the helper names again so they do not linger in the module
    # namespace.
    del sys, petsc4py
    

import petsc4py.PETSc as PETSc
import numpy as array

# World communicator, plus the number of processes in it and this
# process's rank.
COMM = PETSc.COMM_WORLD
SIZE, RANK = COMM.size, COMM.rank

# PETSc options database.
opts = PETSc.Options()

# Script switches: each flag is true when the option of that name is
# present in the options database.
USE_SCHUR = 'schur' in opts
MONITOR = 'monitor' in opts
VIEW = 'view' in opts

#M, N = (101,101)
#M, N = (51,51)
#M, N = (31,31)
#M, N = (9,9)
#M, N = (7,7)

#M, N = (200,200)
# Number of grid nodes along each direction of the global grid.
M, N = (100,100)
#M, N = (64,64)

#M, N = (16,16)
#M, N = (8,8)
#M, N = (6,6)

# Toggle: True numbers the grid nodes with a random permutation, False uses
# the natural row-major numbering 0 .. M*N-1.
SHUFFLE_NUMBERING = True

# ggrid holds the application number of every node of the global grid and
# ends up with shape (M, N) on every process.
if SHUFFLE_NUMBERING:
    from mpi4py import MPI
    import random
    ggrid = array.arange(M*N)
    # Every process shuffles independently; the broadcast below overwrites
    # the local permutation with the one from rank 0 so all processes agree.
    random.shuffle(ggrid)
    # Buffer-based Bcast fills `ggrid` in place and returns None, so its
    # result must NOT be assigned back to `ggrid`.
    MPI.COMM_WORLD.Bcast(ggrid, root=0)
    ggrid.shape = (M,N)
else:
    ggrid = array.arange(M*N)
    ggrid.shape = (M,N)

# Half-open index ranges [ib, ie) x [jb, je) of the part of the global grid
# assigned to this process.
if SIZE == 1:
    # A single process takes the whole grid.
    ib, ie = 0, M
    jb, je = 0, N
else:
    # RANK//2 selects the lower or upper half of the first axis,
    # RANK%2 selects the lower or upper half of the second axis.
    ib, ie = (0, M//2) if RANK//2 == 0 else (M//2, M)
    jb, je = (0, N//2) if RANK%2 == 0 else (N//2, N)

# Application ordering built from the node numbers of this process's
# sub-grid, taken before the ranges are widened below.
ao = PETSc.AOMapping(ggrid[ib:ie, jb:je].copy())

# In parallel, processes holding the lower half along an axis widen their
# range by one grid line in that direction, so the local grid overlaps the
# neighbouring sub-grid by one line of nodes.
if SIZE > 1 and RANK//2 == 0:
    ie += 1
if SIZE > 1 and RANK%2 == 0:
    je += 1

# Private copy of the (possibly widened) local grid and its dimensions.
lgrid = ggrid[ib:ie, jb:je].copy()
m, n = lgrid.shape

# Translate application numbers to PETSc numbers.  The return value is not
# captured and lgrid is used right after, so this presumably updates lgrid
# in place -- TODO confirm against the petsc4py version in use.
ao.ApplicationToPetsc(lgrid)

# Cell connectivity table: one row per cell of the local grid, holding the
# numbers of its four corner nodes in the order
#   (i, j), (i+1, j), (i+1, j+1), (i, j+1).
quads = array.zeros(((m-1)*(n-1),4), dtype=PETSc.Int)
quads[...,0] = lgrid[ 0:m-1 , 0:n-1 ].flat
quads[...,1] = lgrid[ 1:m   , 0:n-1 ].flat
quads[...,2] = lgrid[ 1:m   , 1:n   ].flat
quads[...,3] = lgrid[ 0:m-1 , 1:n   ].flat
# Floor division keeps the row count an integer; true division ("/") yields
# a float under Python 3, which is not a valid array dimension.
quads.shape = (quads.size//4,4)


# Global matrix of size M*N in AIJ format, created on the world
# communicator.
A = PETSc.Mat().create(comm=COMM)
A.setSizes(M*N)
A.setType(PETSc.Mat.Type.AIJ)
# NOTE(review): [5, 1] is presumably the per-row nonzero estimate for the
# diagonal and off-diagonal blocks -- confirm against the petsc4py version
# in use.
A.setPreallocation([5, 1])

# Two vectors obtained from the matrix: b is used as the right-hand side
# and x as the solution later in the script.
b, x = A.getVecs()

# 4x4 element matrix, added once per cell during assembly.  Row and column
# order follows the corner order used in `quads`.
# NOTE(review): the coefficients look like the bilinear (Q1) Laplacian
# stiffness matrix on a square cell -- confirm.
em = array.array(
    [[ 2.0, -0.5, -1.0, -0.5],
     [-0.5,  2.0, -0.5, -1.0],
     [-1.0, -0.5,  2.0, -0.5],
     [-0.5, -1.0, -0.5,  2.0]],
    dtype=PETSc.Scalar,
)
em = 1.0/3.0 * em

# Element right-hand side: 1/4 at each of the four corners.
ev = 1.0/4.0 * array.ones(4, dtype=PETSc.Scalar)

# Application numbers of the four corner nodes of the global grid, in the
# order (0,0), (0,-1), (-1,0), (-1,-1).
corners = [ggrid[i, j] for i in (0, -1) for j in (0, -1)]

# A single process handles all four corners; in parallel each process takes
# only the corner whose position in the list equals its rank.
fix = [corners[RANK]] if SIZE > 1 else corners

# Convert to a PETSc integer array and translate to PETSc numbering (the
# return value is not captured, so this is presumably in place).
fix = array.array(fix, dtype=PETSc.Int)
ao.ApplicationToPetsc(fix)


# Coefficient multiplying the element matrix.
kappa = 1
#kappa=(RANK+1)
ADD = PETSc.InsertMode.ADD_VALUES

# Scaled element matrix, computed once instead of once per cell.
scaled_em = kappa*em

# Matrix: add the element matrix of every cell, assemble, then zero the
# rows of the pinned nodes.
for cell in quads:
    A.setValues(cell, cell, scaled_em, ADD)
A.assemble()
A.zeroRows(fix)

# Right-hand side: add the element vector of every cell and assemble.
for cell in quads:
    b.setValues(cell, ev, ADD)
b.assemble()
#b.set(1)
#b.setRandom()
# Overwrite the entries of the pinned nodes with zero and assemble again.
zeros_at_fix = [0]*len(fix)
b.setValues(fix, zeros_at_fix)
b.assemble()

#PETSc.SyncPrint('[%d] %s\n' % (RANK, A.range))
#PETSc.SyncFlush()
#PETSc.Print('%s\n\n' % ggrid)

#h = 1./(M-1)**2
#A.scale(1./h)
#b.scale(h)

# Start from a zero solution vector.
x.set(0)

# Fill the options database according to the script switches; the KSP
# object reads these when setFromOptions() is called on it.
if MONITOR:
    opts['ksp_monitor'] = 'stdout'

if not USE_SCHUR:
    opts['ksp_type'] = 'cg'
else:
    # Other settings tried here: ksp_type "preonly", sub_pc_type 'none'.
    schur_settings = (
        ("ksp_type", "fgmres"),
        ("pc_type", "schur"),
        ("sub_ksp_type", "cg"),
        ('pc_schur_print_stats', 1),
    )
    for key, value in schur_settings:
        opts[key] = value

    if MONITOR:
        opts['sub_ksp_monitor'] = 'stdout'

    # On a single process, default to 4 local blocks unless the option was
    # already supplied.
    if SIZE == 1 and "pc_schur_local_blocks" not in opts:
        opts['pc_schur_local_blocks'] = 4

if VIEW:
    opts['ksp_view'] = None


# Linear solver on the world communicator.  A is passed as both operator
# arguments, and the solver is then configured from the options database.
ksp = PETSc.KSP()
ksp.create(comm=COMM)
ksp.setOperators(A, A, PETSc.Mat.Structure.SAME)
ksp.setFromOptions()

# Run setup explicitly so a PETSc error raised there is displayed before
# being re-raised.
try:
    ksp.setUp()
    #if VIEW:
    #    ksp.view()
except PETSc.Error:
    PETSc.Error.view()
    raise

## PETSc.SyncFlush();
## PETSc.Print('%s\n' % ggrid)
## PETSc.SyncPrint('%s\n' % lgrid)
## PETSc.SyncFlush();


# Solve with right-hand side b; the solution is written into x.
ksp.solve(b, x)

#draw = PETSc.ViewerDraw(title='Matrix');
#info = PETSc.ViewerASCII(name='stdout',format='info');

#draw(x)

#if not RANK: print array.arange(m*n).reshape(m,n)

#A.scale(0)
#A.view()
