src.acoustools.BEM.Gradients.H_Gradient

  1import torch
  2from torch import Tensor
  3
  4from vedo import Mesh
  5
  6import hashlib, pickle
  7
  8from acoustools.Utilities import device, DTYPE, forward_model_batched, forward_model_grad, forward_model_second_derivative_unmixed
  9from acoustools.BEM.Forward_models import compute_bs, compute_A, get_cache_or_compute_H
 10from acoustools.Mesh import get_centres_as_points, get_normals_as_points, board_name
 11import acoustools.Constants as Constants
 12
 13from acoustools.BEM.Gradients.E_Gradient import get_G_partial
 14
 15
 16 
 17def grad_H(points: Tensor, scatterer: Mesh, transducers: Tensor, return_components:bool = False, 
 18           path:str='', H:Tensor=None, use_cache_H:bool=True) ->tuple[Tensor,Tensor, Tensor] | tuple[Tensor,Tensor, Tensor, Tensor,Tensor, Tensor, Tensor]:
 19    '''
 20    @private
 21    Computes the gradient of H wrt scatterer centres\n
 22    Ignores `points` - for compatability with other gradient functions, takes centres of the scatterers
 23    :param scatterer: The mesh used (as a `vedo` `mesh` object)
 24    :param transducers: Transducers to use 
 25    :param return_components: if true will return the subparts used to compute the derivative
 26    :return grad_H: The gradient of the H matrix wrt the position of the mesh
 27    '''
 28    print("Implementation not tested H grad - probably do correct")
 29    if H is None:
 30        H = get_cache_or_compute_H(scatterer, transducers, use_cache_H, path)
 31    
 32    
 33    # centres = torch.tensor(scatterer.cell_centers().points).to(device).T.unsqueeze_(0)
 34    centres = get_centres_as_points(scatterer)
 35
 36
 37    M = centres.shape[2]
 38
 39    B = compute_bs(scatterer,transducers)
 40    A = compute_A(scatterer)
 41    A_inv = torch.inverse(A).to(DTYPE)
 42
 43    
 44    Bx, By, Bz = forward_model_grad(centres, transducers)
 45    Bx = Bx.to(DTYPE) 
 46    By = By.to(DTYPE)
 47    Bz = Bz.to(DTYPE)
 48
 49
 50    Ax, Ay, Az =  get_G_partial(centres,scatterer,transducers)
 51    # Ax *= -1
 52    # Ay *= -1
 53    # Az *= -1
 54    
 55    Ax = (-1* Ax)
 56    Ay = (-1* Ay)
 57    Az = (-1* Az)
 58
 59
 60    
 61    eye = torch.eye(M).to(bool)
 62    Ax[:,eye] = 0
 63    Ay[:,eye] = 0
 64    Az[:,eye] = 0
 65    
 66    # A_inv_x = (-1*A_inv @ Ax @ A_inv).to(DTYPE)
 67    # A_inv_y = (-1*A_inv @ Ay @ A_inv).to(DTYPE)
 68    # A_inv_z = (-1*A_inv @ Az @ A_inv).to(DTYPE)
 69
 70    # Hx_old = (A_inv_x@B) + (A_inv@Bx)
 71    # Hy_old = (A_inv_y@B) + (A_inv@By)
 72    # Hz_old = (A_inv_z@B) + (A_inv@Bz)
 73
 74
 75    Hx = A_inv @ (Bx - Ax @ H)
 76    Hy = A_inv @ (By - Ay @ H)
 77    Hz = A_inv @ (Bz - Az @ H)
 78
 79
 80    Hx = Hx.to(DTYPE)
 81    Hy = Hy.to(DTYPE)
 82    Hz = Hz.to(DTYPE)
 83
 84    if return_components:
 85        return Hx, Hy, Hz, A, A_inv, Ax, Ay, Az
 86    else:
 87        return Hx, Hy, Hz
 88
 89 
 90def grad_2_H(points: Tensor, scatterer: Mesh, transducers: Tensor, A:Tensor|None = None, 
 91             A_inv:Tensor|None = None, Ax:Tensor|None = None, Ay:Tensor|None = None, Az:Tensor|None = None) -> Tensor:
 92    '''
 93    @private
 94    Computes the second derivative of H wrt scatterer centres\n
 95    Ignores `points` - for compatability with other gradient functions, takes centres of the scatterers
 96    :param scatterer: The mesh used (as a `vedo` `mesh` object)
 97    :param transducers: Transducers to use 
 98    :param A: The result of a call to `compute_A`
 99    :param A_inv: The inverse of `A`
100    :param Ax: The gradient of A wrt the x position of scatterer centres
101    :param Ay: The gradient of A wrt the y position of scatterer centres
102    :param Az: The gradient of A wrt the z position of scatterer centres
103    :return Haa: second order unmixed gradient of H wrt scatterer positions
104    '''
105    print("Implementation not tested H grad - probably do correct")
106    centres = get_centres_as_points(scatterer)
107    M = centres.shape[2]
108
109    B = compute_bs(scatterer,transducers)
110
111    Fx, Fy, Fz = forward_model_grad(centres, transducers)
112    Fx = Fx.to(DTYPE)
113    Fy = Fy.to(DTYPE)
114    Fz = Fz.to(DTYPE)
115    Fa = torch.stack([Fx,Fy,Fz],dim=3)
116
117    Fxx, Fyy, Fzz = forward_model_second_derivative_unmixed(centres, transducers)
118    Faa = torch.stack([Fxx,Fyy,Fzz],dim=3)
119
120    F = forward_model_batched(centres, transducers)
121    
122    if A is None:
123        A = compute_A(scatterer)
124    
125    if A_inv is None:
126        A_inv = torch.inverse(A)
127    
128    if Ax is None or Ay is None or Az is None:
129        Ax, Ay, Az = get_G_partial(centres,scatterer,transducers)
130        eye = torch.eye(M).to(bool)
131        Ax[:,eye] = 0
132        Ay[:,eye] = 0
133        Az[:,eye] = 0
134        Ax = Ax.to(DTYPE)
135        Ay = Ay.to(DTYPE)
136        Az = Az.to(DTYPE)
137    Aa = torch.stack([Ax,Ay,Az],dim=3)
138
139    
140    A_inv_x = (-1*A_inv @ Ax @ A_inv).to(DTYPE)
141    A_inv_y = (-1*A_inv @ Ay @ A_inv).to(DTYPE)
142    A_inv_z = (-1*A_inv @ Az @ A_inv).to(DTYPE)
143
144
145    A_inv_a = torch.stack([A_inv_x,A_inv_y,A_inv_z],dim=3)
146
147    m = centres.permute(0,2,1)
148    m = m.expand((M,M,3))
149
150    m_prime = m.clone()
151    m_prime = m_prime.permute((1,0,2))
152
153    vecs = m - m_prime
154    vecs = vecs.unsqueeze(0)
155    
156
157    # norms = torch.tensor(scatterer.cell_normals).to(device)
158    norms = get_normals_as_points(scatterer,permute_to_points=False)
159    norms = norms.expand(1,M,-1,-1)
160
161    norm_norms = torch.norm(norms,2,dim=3)
162    vec_norms = torch.norm(vecs,2,dim=3)
163    vec_norms_cube = vec_norms**3
164    vec_norms_five = vec_norms**5
165
166    distance = torch.sqrt(torch.sum(vecs**2,dim=3))
167    vecs_square = vecs **2
168    distance_exp = torch.unsqueeze(distance,3)
169    distance_exp = distance_exp.expand(-1,-1,-1,3)
170    
171    distance_exp_cube = distance_exp**3
172
173    distaa = torch.zeros_like(distance_exp)
174    distaa[:,:,:,0] = (vecs_square[:,:,:,1] + vecs_square[:,:,:,2]) 
175    distaa[:,:,:,1] = (vecs_square[:,:,:,0] + vecs_square[:,:,:,2]) 
176    distaa[:,:,:,2] = (vecs_square[:,:,:,1] + vecs_square[:,:,:,0])
177    distaa = distaa / distance_exp_cube
178
179    dista = vecs / distance_exp
180
181    Aaa = (-1 * torch.exp(1j*Constants.k * distance_exp) * (distance_exp*(1-1j*Constants.k*distance_exp))*distaa + dista*(Constants.k**2 * distance_exp**2 + 2*1j*Constants.k * distance_exp -2)) / (4*torch.pi * distance_exp_cube)
182    
183    Baa = (distance_exp * distaa - 2*dista**2) / distance_exp_cube
184
185    Caa = torch.zeros_like(distance_exp).to(device)
186
187    vec_dot_norm = vecs[:,:,:,0]*norms[:,:,:,0]+vecs[:,:,:,1]*norms[:,:,:,1]+vecs[:,:,:,2]*norms[:,:,:,2]
188
189    Caa[:,:,:,0] = ((( (3 * vecs[:,:,:,0]**2) / (vec_norms_five) - (1)/(vec_norms_cube))*(vec_dot_norm)) / norm_norms) - ((2*vecs[:,:,:,0]*norms[:,:,:,0]) / (norm_norms*vec_norms_cube**3))
190    Caa[:,:,:,1] = ((( (3 * vecs[:,:,:,1]**2) / (vec_norms_five) - (1)/(vec_norms_cube))*(vec_dot_norm)) / norm_norms) - ((2*vecs[:,:,:,1]*norms[:,:,:,1]) / (norm_norms*vec_norms_cube**3))
191    Caa[:,:,:,2] = ((( (3 * vecs[:,:,:,2]**2) / (vec_norms_five) - (1)/(vec_norms_cube))*(vec_dot_norm)) / norm_norms) - ((2*vecs[:,:,:,2]*norms[:,:,:,2]) / (norm_norms*vec_norms_cube**3))
192    
193    Gx, Gy, Gz, A_green, B_green, C_green, Aa_green, Ba_green, Ca_green = get_G_partial(centres, scatterer, transducers, return_components=True)
194
195    Gaa = 2*Ca_green*(B_green*Aa_green + A_green*Ba_green) + C_green*(B_green*Aaa + 2*Aa_green*Ba_green + A_green*Baa)+ A_green*B_green*Caa
196    Gaa = Gaa.to(DTYPE)
197
198    areas = torch.Tensor(scatterer.celldata["Area"]).to(device)
199    areas = torch.unsqueeze(areas,0)
200    areas = torch.unsqueeze(areas,0)
201    areas = torch.unsqueeze(areas,3)
202
203    Gaa = Gaa * areas
204    # Gaa = torch.nan_to_num(Gaa)
205    eye = torch.eye(Gaa.shape[2]).to(bool)
206    Gaa[:,eye] = 0
207    
208    
209    A_inv_a = A_inv_a.permute(0,3,2,1)
210    Fa = Fa.permute(0,3,1,2)
211
212    A_inv = A_inv.unsqueeze(1).expand(-1,3,-1,-1)
213    Faa = Faa.permute(0,3,1,2)
214
215    Fa = Fa.to(DTYPE)
216    Faa = Faa.to(DTYPE)
217
218    Gaa = Gaa.permute(0,3,2,1)
219    Aa = Aa.permute(0,3,2,1)
220    Aa = Aa.to(DTYPE)
221
222    X1 = A_inv_a @ Fa + A_inv @ Faa
223    X2 = (A_inv @ (Aa @ A_inv @ Aa - Gaa)@A_inv) @ F
224    X3 = A_inv_a@Fa
225
226
227    Haa = X1 + X2 + X3
228    
229    return Haa
230
231 
def get_cache_or_compute_H_2_gradients(scatterer:Mesh,board:Tensor,use_cache_H_grad:bool=True, path:str="Media", print_lines:bool=False) -> Tensor:
    '''
    @private
    Get second derivatives of H using cache system. Expects a folder named BEMCache in `path`\n
    The cache file is `<path>/BEMCache/<md5(scatterer.filename + "--" + board_name(board))>_2grad.bin`.
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param board: Transducers to use 
    :param use_cache_H_grad: If true loads from / saves to the cache, otherwise computes the second derivatives and does not save them
    :param path: path to folder containing `BEMCache/ `
    :param print_lines: if true prints messages detailing progress
    :return: second derivatives of H, as returned by `grad_2_H`
    '''
    print("Implementation not tested H grad - probably do correct")
    if not use_cache_H_grad:
        if print_lines: print("Computing H grad 2...")
        return grad_2_H(None, scatterer, board)

    f_name = scatterer.filename+"--"+ board_name(board)
    f_name = hashlib.md5(f_name.encode()).hexdigest()
    f_name = path+"/BEMCache/"  +  f_name +"_2grad"+ ".bin"

    if print_lines: print("Trying to load H 2 grads at", f_name ,"...")
    try:
        # NOTE: pickle can execute arbitrary code when loading - only point `path` at trusted cache files
        with open(f_name, "rb") as f:
            Haa = pickle.load(f)
    except FileNotFoundError:
        if print_lines: print("Not found, computing H grad 2...")
        Haa = grad_2_H(None, scatterer, board)
        with open(f_name, "wb") as f:
            pickle.dump(Haa, f)
    else:
        Haa = Haa.to(device)

    return Haa
265
266 
def get_cache_or_compute_H_gradients(scatterer:Mesh,board:Tensor,use_cache_H_grad:bool=True, path:str="Media", print_lines:bool=False) -> tuple[Tensor, Tensor, Tensor]:
    '''
    @private
    Get derivatives of H using cache system. Expects a folder named BEMCache in `path`\n
    The cache file is `<path>/BEMCache/<md5(scatterer.filename + "--" + board_name(board))>_grad.bin`.
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param board: Transducers to use 
    :param use_cache_H_grad: If true loads from / saves to the cache, otherwise computes the derivatives and does not save them
    :param path: path to folder containing `BEMCache/ `. Also forwarded to `grad_H`
    :param print_lines: if true prints messages detailing progress
    :return: `(Hx, Hy, Hz)`, the derivatives of H as returned by `grad_H`
    '''
    print("Implementation not tested for H grad - probably do correct")
    if not use_cache_H_grad:
        if print_lines: print("Computing H Grad...")
        Hx, Hy, Hz = grad_H(None, scatterer, board, path=path)
        return Hx, Hy, Hz

    f_name = scatterer.filename +"--"+ board_name(board)
    f_name = hashlib.md5(f_name.encode()).hexdigest()
    f_name = path+"/BEMCache/"  +  f_name +"_grad"+ ".bin"

    if print_lines: print("Trying to load H grads at", f_name ,"...")
    try:
        # NOTE: pickle can execute arbitrary code when loading - only point `path` at trusted cache files
        with open(f_name, "rb") as f:
            Hx, Hy, Hz = pickle.load(f)
    except FileNotFoundError:
        if print_lines: print("Not found, computing H Grads...")
        Hx, Hy, Hz = grad_H(None, scatterer, board, path=path)
        with open(f_name, "wb") as f:
            pickle.dump((Hx, Hy, Hz), f)
    else:
        Hx = Hx.to(device)
        Hy = Hy.to(device)
        Hz = Hz.to(device)

    return Hx, Hy, Hz