Commit 4cfe7547 authored by Jonathan Lambrechts's avatar Jonathan Lambrechts
Browse files

faster csr creation

parent 640e8a7b
Pipeline #9844 passed with stages
in 22 minutes
import numpy as np
import time
class CSR :
def __init__(self, idx, rhsidx,constraints):
pairs = np.ndarray([idx.shape[0],idx.shape[1],idx.shape[1]],dtype=([('i0',np.int32),('i1',np.int32)]))
pairs['i0'][:,:,:] = idx[:,:,None]
pairs['i1'][:,:,:] = idx[:,None,:]
pairs = pairs.reshape([-1])
allpairs = [pairs.reshape([-1])]
shift = 2**32
num = np.max(idx)
self.ndof = num+1
idx = idx.astype('uint64')
pairs = (idx[:,:,None]*shift+idx[:,None,:]).flatten()
allpairs = [pairs]
self.constraints = constraints
for c in constraints :
num += 1
pairs = np.ndarray([c.size*2+1],dtype=([('i0',np.int32),('i1',np.int32)]))
pairs['i0'][:c.size] = c
pairs['i1'][:c.size] = num
pairs['i0'][c.size:c.size*2] = num
pairs['i1'][c.size:c.size*2] = c
pairs['i0'][c.size*2] = num
pairs['i1'][c.size*2] = num
pairs = np.ndarray([c.size*2+1],np.uint64)
pairs[:c.size] = c*shift+num
pairs[c.size:c.size*2] = num*shift+c
pairs[c.size*2] = num*shift+num
allpairs.append(pairs)
pairs = np.concatenate(allpairs)
pairs, pmap = np.unique(pairs,return_inverse=True)
......@@ -28,9 +25,9 @@ class CSR :
self.map.append(pmap[count:count+p.size])
count += p.size
self.row = np.hstack([np.array([0],dtype=np.int32),
np.cumsum(np.bincount(pairs["i0"]),
np.cumsum(np.bincount((pairs//shift).astype(np.int32)),
dtype=np.int32)])
self.col = pairs['i1'].copy()
self.col = (pairs%shift).astype(np.int32)
self.size = self.row.size-1
self.rhsidx = rhsidx
......
......@@ -151,7 +151,8 @@ class LinearSystemBAIJ :
self.idx = (self.elements[:,None,:]*n_fields+np.arange(n_fields)[None,:,None]).reshape([-1])
self.size = nnodes*n_fields
self.n_fields = n_fields
csrmap = ((self.csr.map[0]*9)[:,None,None]+np.arange(0,9,3)[None,None,:]+np.arange(0,3)[None,:,None])
csrmapl = np.arange(nn*nn)[None,:].reshape(nn,nn).T
csrmap = (self.csr.map[0]*nn*nn)[:,None,None]+csrmapl[None,:,:]
self.csrmap = np.copy(np.swapaxes(csrmap.reshape([elements.shape[0],nn,nn, n_fields,n_fields]),2,3).reshape([-1]))
self.val = np.zeros(self.csr.col.size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment