1import torch
2from torch import Tensor
3
4from vedo import Mesh
5
6import hashlib, pickle
7
8from acoustools.Utilities import device, DTYPE, forward_model_batched, forward_model_grad, forward_model_second_derivative_unmixed
9from acoustools.BEM.Forward_models import compute_bs, compute_A, get_cache_or_compute_H
10from acoustools.Mesh import get_centres_as_points, get_normals_as_points, board_name
11import acoustools.Constants as Constants
12
13from acoustools.BEM.Gradients.E_Gradient import get_G_partial
14
15
16
def grad_H(points: Tensor, scatterer: Mesh, transducers: Tensor, return_components:bool = False,
           path:str='', H:Tensor=None, use_cache_H:bool=True) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    '''
    @private
    Computes the gradient of H wrt scatterer centres\n
    Ignores `points` - for compatibility with other gradient functions, takes centres of the scatterers.

    Differentiating `A @ H = B` gives `dH = A^-1 (dB - dA @ H)`, which is what is evaluated per axis.
    :param points: Unused
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param transducers: Transducers to use
    :param return_components: if true will also return the subparts used to compute the derivative
    :param path: Folder passed to `get_cache_or_compute_H` (only used when `H` is None)
    :param H: Precomputed H matrix. If None it is loaded from cache or computed
    :param use_cache_H: Passed to `get_cache_or_compute_H` (only used when `H` is None)
    :return: `(Hx, Hy, Hz)`, the gradient of the H matrix wrt the position of the mesh, or
        `(Hx, Hy, Hz, A, A_inv, Ax, Ay, Az)` when `return_components` is true
    '''
    print("Implementation not tested H grad - probably do correct")
    if H is None:
        H = get_cache_or_compute_H(scatterer, transducers, use_cache_H, path)

    centres = get_centres_as_points(scatterer)
    M = centres.shape[2]

    A = compute_A(scatterer)
    A_inv = torch.inverse(A).to(DTYPE)

    # dB: gradient of the forward model evaluated at the scatterer centres
    Bx, By, Bz = forward_model_grad(centres, transducers)
    Bx = Bx.to(DTYPE)
    By = By.to(DTYPE)
    Bz = Bz.to(DTYPE)

    # dA is taken as the negated partials of the Green's function matrix.
    # NOTE(review): presumes A depends on G with a minus sign - confirm against compute_A.
    Gx, Gy, Gz = get_G_partial(centres, scatterer, transducers)
    Ax, Ay, Az = -1 * Gx, -1 * Gy, -1 * Gz

    # Self-interaction (diagonal) entries are excluded from the derivative.
    # The mask is built on the same device as the partials it indexes.
    eye = torch.eye(M, dtype=torch.bool, device=Ax.device)
    for A_partial in (Ax, Ay, Az):
        A_partial[:, eye] = 0

    Hx = (A_inv @ (Bx - Ax @ H)).to(DTYPE)
    Hy = (A_inv @ (By - Ay @ H)).to(DTYPE)
    Hz = (A_inv @ (Bz - Az @ H)).to(DTYPE)

    if return_components:
        return Hx, Hy, Hz, A, A_inv, Ax, Ay, Az
    return Hx, Hy, Hz
88
89
def grad_2_H(points: Tensor, scatterer: Mesh, transducers: Tensor, A:Tensor|None = None,
             A_inv:Tensor|None = None, Ax:Tensor|None = None, Ay:Tensor|None = None, Az:Tensor|None = None) -> Tensor:
    '''
    @private
    Computes the second derivative of H wrt scatterer centres\n
    Ignores `points` - for compatibility with other gradient functions, takes centres of the scatterers.

    With H = A^-1 F the second derivative along one axis is
    H'' = (A^-1)'' F + 2 (A^-1)' F' + A^-1 F'',
    where (A^-1)' = -A^-1 A' A^-1 and (A^-1)'' = A^-1 (2 A' A^-1 A' - A'') A^-1.
    :param points: Unused
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param transducers: Transducers to use
    :param A: The result of a call to `compute_A`
    :param A_inv: The inverse of `A`
    :param Ax: The gradient of A wrt the x position of scatterer centres
    :param Ay: The gradient of A wrt the y position of scatterer centres
    :param Az: The gradient of A wrt the z position of scatterer centres
    :return Haa: second order unmixed gradient of H wrt scatterer positions, x/y/z stacked along dim 1
    '''
    print("Implementation not tested H grad - probably do correct")
    centres = get_centres_as_points(scatterer)
    M = centres.shape[2]

    # First and unmixed second derivatives of the forward model at the centres (x/y/z on the last axis)
    Fx, Fy, Fz = forward_model_grad(centres, transducers)
    Fa = torch.stack([Fx.to(DTYPE), Fy.to(DTYPE), Fz.to(DTYPE)], dim=3)

    Fxx, Fyy, Fzz = forward_model_second_derivative_unmixed(centres, transducers)
    Faa = torch.stack([Fxx, Fyy, Fzz], dim=3)

    F = forward_model_batched(centres, transducers)

    if A is None:
        A = compute_A(scatterer)

    if A_inv is None:
        # Cast like grad_H does, so the matmuls below do not mix dtypes
        A_inv = torch.inverse(A).to(DTYPE)

    if Ax is None or Ay is None or Az is None:
        # NOTE(review): these are the raw get_G_partial outputs, whereas grad_H negates them before
        # returning them as components - confirm which sign convention is intended.
        Ax, Ay, Az = get_G_partial(centres, scatterer, transducers)
        eye = torch.eye(M, dtype=torch.bool, device=Ax.device)
        Ax[:, eye] = 0
        Ay[:, eye] = 0
        Az[:, eye] = 0
    Ax = Ax.to(DTYPE)
    Ay = Ay.to(DTYPE)
    Az = Az.to(DTYPE)
    Aa = torch.stack([Ax, Ay, Az], dim=3)

    # (A^-1)' = -A^-1 A' A^-1, one per axis
    A_inv_a = torch.stack([(-1 * A_inv @ A_partial @ A_inv).to(DTYPE) for A_partial in (Ax, Ay, Az)], dim=3)

    # Pairwise separation vectors between centres: vecs[0, i, j] = centre_j - centre_i
    m = centres.permute(0, 2, 1).expand((M, M, 3))
    vecs = (m - m.permute((1, 0, 2))).unsqueeze(0)
    vecs_square = vecs**2

    # Per-cell normals, broadcast over the first M axis
    norms = get_normals_as_points(scatterer, permute_to_points=False)
    norms = norms.expand(1, M, -1, -1)
    norm_norms = torch.norm(norms, 2, dim=3).unsqueeze(3)

    # d = |r| broadcast over the x/y/z axis
    distance = torch.sqrt(torch.sum(vecs_square, dim=3))
    distance_exp = distance.unsqueeze(3).expand(-1, -1, -1, 3)
    distance_exp_cube = distance_exp**3
    distance_exp_five = distance_exp**5

    # d' = a/|r| and d'' = (|r|^2 - a^2)/|r|^3 for each axis a
    dista = vecs / distance_exp
    distaa = torch.stack([
        vecs_square[:, :, :, 1] + vecs_square[:, :, :, 2],
        vecs_square[:, :, :, 0] + vecs_square[:, :, :, 2],
        vecs_square[:, :, :, 1] + vecs_square[:, :, :, 0],
    ], dim=3) / distance_exp_cube

    k = Constants.k
    # Second derivative of exp(ikd)/(4 pi d):
    #   -exp(ikd) [ d (1 - ikd) d'' + (k^2 d^2 + 2ikd - 2) d'^2 ] / (4 pi d^3)
    # The exponential and minus sign apply to both terms, and d' enters squared.
    Aaa = (-1 * torch.exp(1j * k * distance_exp)
           * (distance_exp * (1 - 1j * k * distance_exp) * distaa
              + dista**2 * (k**2 * distance_exp**2 + 2 * 1j * k * distance_exp - 2))
           ) / (4 * torch.pi * distance_exp_cube)

    # Second derivative of (ik - 1/d)
    Baa = (distance_exp * distaa - 2 * dista**2) / distance_exp_cube

    # Second derivative of (r.n)/(|r||n|):
    #   ((3 a^2/|r|^5 - 1/|r|^3)(r.n) - 2 a n_a/|r|^3) / |n|
    vec_dot_norm = torch.sum(vecs * norms, dim=3).unsqueeze(3)
    Caa = (((3 * vecs_square) / distance_exp_five - 1 / distance_exp_cube) * vec_dot_norm) / norm_norms \
        - (2 * vecs * norms) / (norm_norms * distance_exp_cube)

    _, _, _, A_green, B_green, C_green, Aa_green, Ba_green, Ca_green = get_G_partial(centres, scatterer, transducers, return_components=True)

    # Product rule for (A B C)''
    Gaa = 2 * Ca_green * (B_green * Aa_green + A_green * Ba_green) \
        + C_green * (B_green * Aaa + 2 * Aa_green * Ba_green + A_green * Baa) \
        + A_green * B_green * Caa
    Gaa = Gaa.to(DTYPE)

    # Weight each column by the area of the corresponding cell
    areas = torch.Tensor(scatterer.celldata["Area"]).to(device)
    areas = torch.unsqueeze(areas, 0)
    areas = torch.unsqueeze(areas, 0)
    areas = torch.unsqueeze(areas, 3)
    Gaa = Gaa * areas

    # Self-interaction terms (division by zero distance) are excluded
    eye = torch.eye(Gaa.shape[2], dtype=torch.bool, device=Gaa.device)
    Gaa[:, eye] = 0

    # Move the x/y/z axis to dim 1 so each axis is a batched matrix.
    # NOTE(review): permute(0,3,2,1) also transposes the two M axes of A_inv_a, Aa and Gaa,
    # while A_inv is not transposed - confirm this matches get_G_partial's index convention.
    A_inv_a = A_inv_a.permute(0, 3, 2, 1)
    Aa = Aa.permute(0, 3, 2, 1).to(DTYPE)
    Gaa = Gaa.permute(0, 3, 2, 1)
    Fa = Fa.permute(0, 3, 1, 2).to(DTYPE)
    Faa = Faa.permute(0, 3, 1, 2).to(DTYPE)
    A_inv = A_inv.unsqueeze(1).expand(-1, 3, -1, -1)

    # 2 (A^-1)' F' + A^-1 F''
    first_order = A_inv_a @ Fa
    X1 = first_order + A_inv @ Faa
    X3 = first_order
    # (A^-1)'' F with (A^-1)'' = A^-1 (2 A' A^-1 A' - A'') A^-1
    X2 = (A_inv @ (2 * Aa @ A_inv @ Aa - Gaa) @ A_inv) @ F

    Haa = X1 + X2 + X3
    return Haa
230
231
def get_cache_or_compute_H_2_gradients(scatterer:Mesh,board:Tensor,use_cache_H_grad:bool=True, path:str="Media", print_lines:bool=False) -> Tensor:
    '''
    @private
    Get second derivatives of H using cache system. Expects a folder named BEMCache in `path`\n
    The cache file name is the md5 of the scatterer filename and the board name, with suffix `_2grad.bin`.
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param board: Transducers to use
    :param use_cache_H_grad: If true uses the cache system, otherwise computes the second derivatives and does not save them
    :param path: path to folder containing `BEMCache/ `
    :param print_lines: if true prints messages detailing progress
    :return: second derivatives of H
    '''
    print("Implementation not tested H grad - probably do correct")
    if not use_cache_H_grad:
        if print_lines: print("Computing H grad 2...")
        return grad_2_H(None, transducers=board, scatterer=scatterer)

    f_name = scatterer.filename + "--" + board_name(board)
    f_name = hashlib.md5(f_name.encode()).hexdigest()
    f_name = path + "/BEMCache/" + f_name + "_2grad" + ".bin"

    if print_lines: print("Trying to load H 2 grads at", f_name, "...")
    try:
        # NOTE: pickle can execute arbitrary code on load - only point `path` at cache files written by this code.
        with open(f_name, "rb") as f:
            Haa = pickle.load(f)
    except FileNotFoundError:
        if print_lines: print("Not found, computing H grad 2...")
        Haa = grad_2_H(None, transducers=board, scatterer=scatterer)
        with open(f_name, "wb") as f:
            pickle.dump(Haa, f)
    else:
        Haa = Haa.to(device)

    return Haa
265
266
def get_cache_or_compute_H_gradients(scatterer:Mesh,board:Tensor,use_cache_H_grad:bool=True, path:str="Media", print_lines:bool=False) -> tuple[Tensor, Tensor, Tensor]:
    '''
    @private
    Get derivatives of H using cache system. Expects a folder named BEMCache in `path`\n
    The cache file name is the md5 of the scatterer filename and the board name, with suffix `_grad.bin`.
    :param scatterer: The mesh used (as a `vedo` `mesh` object)
    :param board: Transducers to use
    :param use_cache_H_grad: If true uses the cache system, otherwise computes the derivatives and does not save them
    :param path: path to folder containing `BEMCache/`; also passed on to `grad_H`
    :param print_lines: if true prints messages detailing progress
    :return: derivatives of H as `(Hx, Hy, Hz)`
    '''
    print("Implementation not tested for H grad - probably do correct")
    if not use_cache_H_grad:
        if print_lines: print("Computing H Grad...")
        Hx, Hy, Hz = grad_H(None, transducers=board, scatterer=scatterer, path=path)
        return Hx, Hy, Hz

    f_name = scatterer.filename + "--" + board_name(board)
    f_name = hashlib.md5(f_name.encode()).hexdigest()
    f_name = path + "/BEMCache/" + f_name + "_grad" + ".bin"

    if print_lines: print("Trying to load H grads at", f_name, "...")
    try:
        # NOTE: pickle can execute arbitrary code on load - only point `path` at cache files written by this code.
        with open(f_name, "rb") as f:
            Hx, Hy, Hz = pickle.load(f)
    except FileNotFoundError:
        if print_lines: print("Not found, computing H Grads...")
        Hx, Hy, Hz = grad_H(None, transducers=board, scatterer=scatterer, path=path)
        with open(f_name, "wb") as f:
            pickle.dump((Hx, Hy, Hz), f)
    else:
        Hx = Hx.to(device)
        Hy = Hy.to(device)
        Hz = Hz.to(device)

    return Hx, Hy, Hz