"""Core functions of SRW computations."""
from copy import deepcopy
import multiprocessing as mp
import numpy as np
from . import srwlib
from . import configuration
from .propagation import groupOpticalSystem, cropOpticalSystem, \
parseToSRWelement
from .tools import tupledeepcopy
from .wavefronts import Observer, Wavefront, WavefrontMultiWavelength, Stokes
# ========================== TRAJECTORY COMPUTATION ===========================
[docs]
def computeTrajectory(particle, magnetsContainer,
                      sStart=None, sEnd=None, res=None):
    r"""Compute the trajectory of a particle inside a magnetic structure.

    Args:
        particle: an instance of :py:class:`~.emitters.Particle` or
            :py:class:`~.emitters.ParticleBeam`. When the object has a
            `nominalParticle` attribute, that particle is tracked instead.
        magnetsContainer: the magnet container instance of
            :py:class:`~.magnets.MagnetsContainer`.
        sStart (optional): start of the trajectory computation in
            :math:`c \cdot t`. If None, the `zStart` of the
            `magnetsContainer` is used. Defaults to None.
        sEnd (optional): end of the trajectory computation in
            :math:`c \cdot t`. If None, the `zEnd` of the `magnetsContainer`
            is used. Defaults to None.
        res (optional): resolution (step) of the trajectory computation in
            :math:`c \cdot t`; the number of points is then
            ``int((sEnd - sStart) / res)``. If None, the `nPts` of the
            `magnetsContainer` is used. Defaults to None.

    Returns:
        dict: trajectory with the positions ("z", "x", "y"), the deflections
        ("xp", "yp") and the transverse field experienced by the particle
        ("Bx", "By", divided by ``configuration.CONFIG.B_SCALING``), all as
        numpy arrays.
    """
    # track a private copy, since relE0 and nq are overridden below
    tracked = deepcopy(getattr(particle, "nominalParticle", particle))
    tracked.relE0 = 1  # field already scaled
    tracked.nq = -1  # need inversion to preserve direction convention

    zStart, zEnd, nPtsMag = magnetsContainer.getIntegrationParams()
    ctStart = zStart if sStart is None else sStart
    ctEnd = zEnd if sEnd is None else sEnd
    nPtsTraj = nPtsMag if res is None else int((ctEnd - ctStart) / res)

    srwTrj = srwlib.SRWLPrtTrj(_partInitCond=tracked, _ctStart=ctStart,
                               _ctEnd=ctEnd, _np=nPtsTraj)
    srwTrj.allocate(nPtsTraj, True)
    srwTrj = srwlib.srwl.CalcPartTraj(srwTrj, magnetsContainer.srwCnt, [1])

    bScaling = configuration.CONFIG.B_SCALING
    return {
        "z": np.array(srwTrj.arZ),
        "x": np.array(srwTrj.arX),
        "y": np.array(srwTrj.arY),
        "xp": np.array(srwTrj.arXp),
        "yp": np.array(srwTrj.arYp),
        "Bx": np.array(srwTrj.arBx) / bScaling,
        "By": np.array(srwTrj.arBy) / bScaling,
    }
# ============================ SPECTRUM COMPUTATION ===========================
[docs]
def computeSpectrum(particleBeam, magnetsContainer, observer,
                    fromWavelength=None, toWavelength=None, numPoints=None,
                    relPrec=None, srApprox=None):
    """Compute the spectrum of synchrotron radiation emitted by a particle
    travelling along a magnetic structure.

    Args:
        particleBeam: particle instance of :py:class:`~.emitters.ParticleBeam`.
        magnetsContainer: container instance of
            :py:class:`~.magnets.MagnetsContainer`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
            Check :py:class:`~.wavefronts.WavefrontMultiWavelength` for further
            details about the observer definition.
        fromWavelength: lower wavelength in nanometers. Forwarded as-is to
            :py:class:`~.wavefronts.WavefrontMultiWavelength`.
        toWavelength: upper wavelength in nanometers. Forwarded as-is to
            :py:class:`~.wavefronts.WavefrontMultiWavelength`.
        numPoints: number of wavelengths to simulate. Forwarded as-is to
            :py:class:`~.wavefronts.WavefrontMultiWavelength`.
        relPrec (optional): relative precision target for the convergence of
            the computation. If None, the default value set in
            :py:class:`~.configuration.CONFIG` is used. Defaults to None.
        srApprox (optional): key of :py:data:`.configuration.SR_APPROX`
            selecting the approximation method used for the computation. If
            None, the default value set in :py:class:`~.configuration.CONFIG`
            is used. Defaults to None.

    Returns:
        :py:class:`~.wavefronts.WavefrontMultiWavelength`: the calculated
        field spectral distribution.
    """
    if relPrec is None:
        relPrec = configuration.CONFIG.REL_PREC
    if srApprox is None:
        srApprox = configuration.CONFIG.SR_APPROX
    wfr = WavefrontMultiWavelength(observer,
                                   fromWavelength, toWavelength, numPoints)
    wfr.partBeam = particleBeam  # attach the beam required by CalcElecFieldSR
    zStart, zEnd, nPts = magnetsContainer.getIntegrationParams()
    # CalcElecFieldSR precision list: SRW method code, relative precision,
    # integration start / end, number of trajectory points, then the two
    # trailing flags (1, 0) -- see the SRW CalcElecFieldSR documentation
    params = [configuration.SR_APPROX[srApprox], relPrec, zStart, zEnd, nPts,
              1, 0]
    srwlib.srwl.CalcElecFieldSR(wfr, 0, magnetsContainer.srwCnt, params)
    return wfr
# =========================== WAVEFRONT COMPUTATION ===========================
[docs]
def computePtSrcWfr(pointSource, observer, wavelength):
    """Compute the wavefront emitted by a point source.

    Args:
        pointSource: point source instance of :py:class:`~.emitters.PointSource`.
        observer: observation mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    ptSrcWfr = Wavefront(observer, wavelength)
    precPar = [0]  # precision parameters handed to SRW unchanged
    srwlib.srwl.CalcElecFieldPointSrc(ptSrcWfr, pointSource, precPar)
    return ptSrcWfr
[docs]
def computePtSrcWfrMultiProcess(numCores, pointSource, observer, wavelength):
    """Compute the wavefront of a point source on several processes.

    Multiprocessing version of :py:func:`~computePtSrcWfr`: the observer mesh
    is split into vertical stripes computed in parallel and then merged.

    Args:
        numCores: the number of processes to split the computation into.
            Maximum value limited to the CPU count of the machine.
        pointSource: point source instance of :py:class:`~.emitters.PointSource`.
        observer: observation mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    # keyword arguments forwarded to computePtSrcWfr by the pool worker
    workerKwargs = {
        "pointSource": pointSource,
        "observer": observer,
        "wavelength": wavelength,
    }
    return _computeWfrMultiProcess(numCores, _compPtSrcWrapper, workerKwargs)
[docs]
def computeGsnBeamWfr(gaussianBeam, observer, wavelength):
    """Compute the wavefront of a coherent Gaussian beam.

    Args:
        gaussianBeam: source instance of :py:class:`~.emitters.GaussianBeam`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    gsnWfr = Wavefront(observer, wavelength)
    precPar = [0]  # precision parameters handed to SRW unchanged
    srwlib.srwl.CalcElecFieldGaussian(gsnWfr, gaussianBeam, precPar)
    return gsnWfr
[docs]
def computeGsnBeamWfrMultiProcess(numCores, gaussianBeam, observer, wavelength):
    """Compute the wavefront of a coherent Gaussian beam on several processes.

    Multiprocessing version of :py:func:`~computeGsnBeamWfr`: the observer
    mesh is split into vertical stripes computed in parallel and then merged.

    Args:
        numCores: the number of processes to split the computation into.
            Maximum value limited to the CPU count of the machine.
        gaussianBeam: source instance of :py:class:`~.emitters.GaussianBeam`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    # keyword arguments forwarded to computeGsnBeamWfr by the pool worker
    workerKwargs = {
        "gaussianBeam": gaussianBeam,
        "observer": observer,
        "wavelength": wavelength,
    }
    return _computeWfrMultiProcess(numCores, _compGsnBeamWrapper, workerKwargs)
[docs]
def computeSrWfr(particleBeam, magnetsContainer, observer, wavelength,
                 relPrec=None, srApprox=None):
    """Compute the synchrotron radiation wavefront emitted by a particle
    travelling along a magnetic structure.

    Args:
        particleBeam: particle instance of :py:class:`~.emitters.ParticleBeam`.
        magnetsContainer: container instance of
            :py:class:`~.magnets.MagnetsContainer`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.
        relPrec (optional): relative precision target for the convergence of
            the computation. If None, the default value set in
            :py:class:`~.configuration.CONFIG` is used. Defaults to None.
        srApprox (optional): key of :py:data:`.configuration.SR_APPROX`
            selecting the approximation method used for the computation. If
            None, the default value set in :py:class:`~.configuration.CONFIG`
            is used. Defaults to None.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    if relPrec is None:
        relPrec = configuration.CONFIG.REL_PREC
    if srApprox is None:
        srApprox = configuration.CONFIG.SR_APPROX
    wfr = Wavefront(observer, wavelength)
    wfr.partBeam = particleBeam  # attach the beam required by CalcElecFieldSR
    zStart, zEnd, nPts = magnetsContainer.getIntegrationParams()
    # CalcElecFieldSR precision list: SRW method code, relative precision,
    # integration start / end, number of trajectory points, then the two
    # trailing flags (1, 0) -- see the SRW CalcElecFieldSR documentation
    params = [configuration.SR_APPROX[srApprox], relPrec, zStart, zEnd, nPts,
              1, 0]
    srwlib.srwl.CalcElecFieldSR(wfr, 0, magnetsContainer.srwCnt, params)
    return wfr
[docs]
def computeSrWfrMultiProcess(numCores, particleBeam, magnetsContainer,
                             observer, wavelength,
                             relPrec=None, srApprox=None):
    """Compute the synchrotron radiation wavefront on several processes.

    Multiprocessing version of :py:func:`~computeSrWfr`: the observer mesh is
    split into vertical stripes computed in parallel and then merged.

    Args:
        numCores: the number of processes to split the computation into.
            Maximum value limited to the CPU count of the machine.
        particleBeam: particle instance of :py:class:`~.emitters.ParticleBeam`.
        magnetsContainer: container instance of
            :py:class:`~.magnets.MagnetsContainer`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
        wavelength: observation wavelength in nanometers.
        relPrec (optional): relative precision target for the convergence of
            the computation. If None, the default value set in
            :py:class:`~.configuration.CONFIG` is used. Defaults to None.
        srApprox (optional): key of :py:data:`.configuration.SR_APPROX`
            selecting the approximation method used for the computation. If
            None, the default value set in :py:class:`~.configuration.CONFIG`
            is used. Defaults to None.

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the calculated wavefront.
    """
    # explicit keyword arguments forwarded to computeSrWfr by the pool worker
    # (avoids relying on the locals() snapshot semantics)
    inputs = {
        "particleBeam": particleBeam,
        "magnetsContainer": magnetsContainer,
        "observer": observer,
        "wavelength": wavelength,
        "relPrec": relPrec,
        "srApprox": srApprox,
    }
    return _computeWfrMultiProcess(numCores, _compSrWrapper, inputs)
def _computeWfrMultiProcess(numCores, compWrapper, inputs):
    """Split a wavefront computation into vertical stripes run in parallel.

    The mesh of ``inputs["observer"]`` is split along its vertical axis into
    sub-observers; each one is computed in a separate process by
    ``compWrapper`` and the field arrays are concatenated back into a single
    wavefront covering the full observer.

    Args:
        numCores: requested number of worker processes. Capped to the CPU
            count of the machine and to the number of vertical mesh points.
        compWrapper: module-level (hence picklable) worker that takes a dict
            of keyword arguments and returns a wavefront.
        inputs: dictionary ``{argument_name: value}`` forwarded to
            ``compWrapper``; must contain "observer" and "wavelength".

    Returns:
        :py:class:`~.wavefronts.Wavefront`: the merged wavefront.
    """
    if numCores > mp.cpu_count():
        numCores = mp.cpu_count()
        print(f"Limiting cores to maximum available of {numCores}")
    observer = inputs["observer"]
    wavelength = inputs["wavelength"]
    # with more stripes than vertical points np.array_split returns empty
    # chunks, on which min()/max() below would raise ValueError
    numCores = min(numCores, len(observer.yax))

    # prepare wavefront with full observer; its field arrays are filled below
    wfrFull = Wavefront(observer, wavelength)
    wfrFull.arEx = srwlib.array("f")
    wfrFull.arEy = srwlib.array("f")

    # one sub-observer per stripe of consecutive vertical points, each
    # keeping the full horizontal axis of the original observer
    poolInputsList = []
    for yax in np.array_split(observer.yax, numCores):
        poolInputs = deepcopy(inputs)
        poolInputs["observer"] = Observer(
            centerCoord=[observer.centerCoord[0],
                         np.mean([min(yax), max(yax)]),
                         observer.centerCoord[2]],
            obsXextension=observer.obsXextension,
            obsYextension=max(yax) - min(yax),
            nx=len(observer.xax),
            ny=len(yax))
        poolInputsList.append(poolInputs)

    # run multiprocess computation; the context manager releases the pool
    # even if a worker raises
    with mp.Pool(numCores) as pool:
        wfrSplits = pool.map(compWrapper, poolInputsList)

    # NOTE(review): plain concatenation assumes SRW stores the field row by
    # row (vertical index outermost), so stripes in y order rebuild the mesh
    for wfrSplit in wfrSplits:
        wfrFull.arEx.extend(wfrSplit.arEx)
        wfrFull.arEy.extend(wfrSplit.arEy)
    # the remaining wavefront attributes are taken from the first stripe
    attribToCopy = ["xc", "yc", "numTypeElFld", "dRx", "dRy", "avgPhotEn",
                    "arElecPropMatr", "Rx", "Ry"]
    for attrib in attribToCopy:
        setattr(wfrFull, attrib, getattr(wfrSplits[0], attrib))
    return wfrFull
# only module-level functions can be pickled by pool.map
def _compGsnBeamWrapper(poolInputs):
    """Pool worker calling :py:func:`computeGsnBeamWfr` with keyword args.

    Args:
        poolInputs: dictionary of keyword arguments for computeGsnBeamWfr.
    """
    gsnWfr = computeGsnBeamWfr(**poolInputs)
    return gsnWfr
def _compPtSrcWrapper(poolInputs):
    """Pool worker calling :py:func:`computePtSrcWfr` with keyword args.

    Args:
        poolInputs: dictionary of keyword arguments for computePtSrcWfr.
    """
    ptSrcWfr = computePtSrcWfr(**poolInputs)
    return ptSrcWfr
def _compSrWrapper(poolInputs):
    """Pool worker calling :py:func:`computeSrWfr` with keyword args.

    Args:
        poolInputs: dictionary of keyword arguments for computeSrWfr.
    """
    srWfr = computeSrWfr(**poolInputs)
    return srWfr
# =============================== POWER DENSITY ===============================
[docs]
def computeSrPowerDensity(particleBeam, magnetsContainer, observer,
                          relPrec=None):
    """Compute the synchrotron radiation power density.

    Args:
        particleBeam: particle instance of :py:class:`~.emitters.ParticleBeam`.
        magnetsContainer: container instance of
            :py:class:`~.magnets.MagnetsContainer`.
        observer: mesh instance of :py:class:`~.wavefronts.Observer`.
        relPrec (optional): precision factor for computation.
            Pass > 1.0 for more precision. If None, 1.0 is used.
            Defaults to None.

    Returns:
        dict: a dictionary containing the two coordinate axes as well as
        the matrix of the power density data in W/mm^2.
    """
    if relPrec is None:
        relPrec = 1.0
    stk = Stokes(observer, 1.)  # dummy wavelength not used for CalcPowDenSR
    zStart, zEnd, nPts = magnetsContainer.getIntegrationParams()
    # CalcPowDenSR precision list: precision factor, method flag (1; see the
    # SRW CalcPowDenSR documentation), integration start / end, points
    params = [relPrec, 1, zStart, zEnd, nPts]
    srwlib.srwl.CalcPowDenSR(stk, particleBeam, 0,
                             magnetsContainer.srwCnt, params)
    return stk.getPwrDensity()
# ================================ PROPAGATION ================================
[docs]
def propagateWfr(wfr, opticalSystem, pol="TOT", groupSeq=["lens", "drift"],
                 saveWfrAt=[], saveIntAt=None,
                 startPropAt=None, stopPropAt=None):
    """Propagate a wavefront through an optical system.

    Args:
        wfr: the wavefront instance of :py:class:`~.wavefronts.Wavefront`.
            It is never modified: the propagation works on a deep copy.
        opticalSystem: dictionary defining the sequence of optical elements.
        pol (optional): one of the :py:const:`~pysrw.configuration.POL`
            polarization states or an angle in degrees for any arbitrary
            linear polarization.
            Defaults to "TOT".
        groupSeq (optional): list of elements to merge in the propagation.
            The propagation through a sequence is entirely performed at the SRW
            library level and the intermediate planes will not appear in the
            `propData` output. Use this option for faster and memory-efficient
            computation, as the same wavefront object is overridden at each
            propagation stage.
            Defaults to ["lens", "drift"].
        saveWfrAt (optional): list of optical element keys where the wavefront
            should be stored in `propData`. If None, the wavefront is always
            stored. Defaults to [].
        saveIntAt (optional): list of optical element keys where the intensity
            should be stored in `propData`. If None, the intensity is always
            stored. Defaults to None.
        startPropAt (optional): start the propagation at a specific optical
            element key. If None, the propagation starts at the first element.
            Defaults to None.
        stopPropAt (optional): stop the propagation at a specific optical
            element key. If None, the propagation stops at the last element.
            Defaults to None.

    Returns:
        dict: keyed by the last optical element key of each propagated group,
        each value being ``{"wfr": wavefront or None,
        "intensity": intensity or None}`` depending on `saveWfrAt` and
        `saveIntAt`.
    """
    # forward the caller's start / stop keys (they used to be ignored)
    optSystemCrop = cropOpticalSystem(opticalSystem,
                                      startPropAt=startPropAt,
                                      stopPropAt=stopPropAt)
    propData = {}
    # preserve always the input wavefront: propagate a private copy
    wfrAfter = deepcopy(wfr)
    for optCnt in groupOpticalSystem(optSystemCrop, groupSeq):
        SRWoptEls = []
        SRWpropPars = []
        for optElEntry in optCnt:
            # each entry maps an element key to its definition; only the
            # first value is used
            optEl = next(iter(optElEntry.values()))
            SRWoptEl, SRWpropPar = parseToSRWelement(optEl)
            SRWoptEls.append(SRWoptEl)
            SRWpropPars.append(SRWpropPar)
        SRWoptCnt = srwlib.SRWLOptC(SRWoptEls, SRWpropPars)
        srwlib.srwl.PropagElecField(wfrAfter, SRWoptCnt)

        # results are stored under the key of the group's last element
        lastOptElKey = next(iter(optCnt[-1]))
        lastOptEl = optCnt[-1][lastOptElKey]
        if "resParam" in lastOptEl:
            wfrAfter = resizeWfr(wfrAfter, **lastOptEl["resParam"])

        if saveWfrAt is None or lastOptElKey in saveWfrAt:
            # copy, since wfrAfter keeps being modified by later groups
            wfrToSave = deepcopy(wfrAfter)
        else:
            wfrToSave = None
        if saveIntAt is None or lastOptElKey in saveIntAt:
            intToSave = wfrAfter.getWfrI(pol=pol)
        else:
            intToSave = None
        propData[lastOptElKey] = {"wfr": wfrToSave, "intensity": intToSave}
    return propData
[docs]
def propagateWfrMultiProcess(numCores, wfr, opticalSystem, pol="TOT",
                             groupSeq=["lens", "drift"],
                             saveWfrAt=[], saveIntAt=None,
                             startPropAt=None, stopPropAt=None):
    """Propagate many wavefronts through an optical system or a single
    wavefront through many optical systems.

    Args:
        numCores: the number of processes to split the computation into.
            Maximum value limited to the CPU count of the machine.
        wfr: wavefront instance of :py:class:`~.wavefronts.Wavefront` or list
            of many wavefront objects.
        opticalSystem: dictionary defining the sequence of optical elements or
            list of many optical system dictionaries.
        pol (optional): one of the :py:const:`~pysrw.configuration.POL`
            polarization states or an angle in degrees for any arbitrary
            linear polarization.
            Defaults to "TOT".
        groupSeq (optional): list of elements to merge in the propagation.
            The propagation through a sequence is entirely performed at the SRW
            library level and the intermediate planes will not appear in the
            `propData` output. Use this option for faster and memory-efficient
            computation, as the same wavefront object is overridden at each
            propagation stage.
            Defaults to ["lens", "drift"].
        saveWfrAt (optional): list of optical element keys where the wavefront
            should be stored in `propData`. If None, the wavefront is always
            stored. Defaults to [].
        saveIntAt (optional): list of optical element keys where the intensity
            should be stored in `propData`. If None, the intensity is always
            stored. Defaults to None.
        startPropAt (optional): start the propagation at a specific optical
            element key. If None, the propagation starts at the first element.
            Defaults to None.
        stopPropAt (optional): stop the propagation at a specific optical
            element key. If None, the propagation stops at the last element.
            Defaults to None.

    Returns:
        list: a list of dictionaries with the propagation data returned by
        :py:func:`~propagateWfr`, in input order.

    Raises:
        ValueError: if both or neither of `wfr` and `opticalSystem` are
            lists.
    """
    if numCores > mp.cpu_count():
        numCores = mp.cpu_count()
        print(f"Limiting cores to maximum available of {numCores}")
    parallelWfr = isinstance(wfr, list)
    parallelOptSys = isinstance(opticalSystem, list)
    if parallelWfr == parallelOptSys:
        # ValueError subclasses Exception, so existing handlers keep working
        raise ValueError("either many wavefronts or many systems supported")
    # positional arguments of propagateWfr after (wfr, opticalSystem)
    sharedArgs = [pol, groupSeq, saveWfrAt, saveIntAt,
                  startPropAt, stopPropAt]
    if parallelWfr:
        poolInputs = [deepcopy([singleWfr, opticalSystem] + sharedArgs)
                      for singleWfr in wfr]
    else:
        poolInputs = [deepcopy([wfr, singleOptSys] + sharedArgs)
                      for singleOptSys in opticalSystem]
    # run multiprocess computation; the pool is released even on errors
    with mp.Pool(numCores) as pool:
        propData = pool.starmap(propagateWfr, poolInputs)
    return propData
[docs]
def resizeWfr(wfr,
hResChange=1.0, vResChange=1.0,
hRangeChange=1.0, vRangeChange=1.0,
hCentreChange=0.5, vCentreChange=0.5):
"""Change the extension, resolution or centre of a wavefront.
Args:
wfr: the :py:class:`~.wavefronts.Wavefront`: input wavefront.
hResChange (optional): rescaling factor to increase (>1.0) or
decrease (<1.0) the horizontal resolution.
Defaults to 1.0.
vResChange (optional): rescaling factor to increase (>1.0) or
decrease (<1.0) the vertical resolution.
Defaults to 1.0.
hRangeChange (optional): rescaling factor to increase (>1.0) or
decrease (<1.0) the horizontal extension.
Defaults to 1.0.
vRangeChange (optional): rescaling factor to increase (>1.0) or
decrease (<1.0) the vertical extension.
Defaults to 1.0.
hCentreChange (optional): new horizontal centre.
Defaults to 0.0.
vCentreChange (optional): new vertical centre.
Defaults to 0.0.
Returns:
:py:class:`~.wavefronts.Wavefront`: the resized wavefront.
"""
if [hResChange,vResChange, hRangeChange, vRangeChange] == [1, 1, 1, 1]:
# add negligible range change to anyway trigger CentreChange
hRangeChange = hRangeChange + 1 / wfr.mesh.nx
vRangeChange = vRangeChange + 1 / wfr.mesh.ny
# ResizeElecField resize hangs if shift is smaller than binning size
hCentreRel = (hCentreChange - wfr.mesh.xStart) / (wfr.mesh.xFin - wfr.mesh.xStart)
vCentreRel = (vCentreChange - wfr.mesh.yStart) / (wfr.mesh.yFin - wfr.mesh.yStart)
wfr = srwlib.srwl.ResizeElecField(wfr, 'c',
[0, hRangeChange, hResChange, vRangeChange, vResChange,
hCentreRel, vCentreRel])
return wfr
[docs]
def resizeWfrMultiProcess(numCores, wfrList,
                          hResChange=1.0, vResChange=1.0,
                          hRangeChange=1.0, vRangeChange=1.0,
                          hCentreChange=None, vCentreChange=None):
    """Change the extension, resolution or centre of many wavefronts.

    Multiprocess version of :py:func:`~resizeWfr`: each wavefront of the list
    is resized in a separate task.

    Args:
        numCores: the number of processes to split the computation into.
            Maximum value limited to the CPU count of the machine.
        wfrList: list of :py:class:`~.wavefronts.Wavefront` input wavefronts.
        hResChange (optional): rescaling factor to increase (>1.0) or
            decrease (<1.0) the horizontal resolution.
            Defaults to 1.0.
        vResChange (optional): rescaling factor to increase (>1.0) or
            decrease (<1.0) the vertical resolution.
            Defaults to 1.0.
        hRangeChange (optional): rescaling factor to increase (>1.0) or
            decrease (<1.0) the horizontal extension.
            Defaults to 1.0.
        vRangeChange (optional): rescaling factor to increase (>1.0) or
            decrease (<1.0) the vertical extension.
            Defaults to 1.0.
        hCentreChange (optional): new horizontal centre as an absolute mesh
            coordinate. If None, each wavefront keeps its current centre.
            Defaults to None.
        vCentreChange (optional): new vertical centre as an absolute mesh
            coordinate. If None, each wavefront keeps its current centre.
            Defaults to None.

    Returns:
        :py:class:`~.wavefronts.Wavefront` when `wfrList` has exactly one
        element, otherwise the list of resized wavefronts in input order.
    """
    if numCores > mp.cpu_count():
        numCores = mp.cpu_count()
        print(f"Limiting cores to maximum available of {numCores}")
    poolInputs = []
    for wfr in wfrList:
        mesh = wfr.mesh
        # "keep the centre" is resolved per wavefront, since every wavefront
        # of the list may have its own mesh
        hCentre = (0.5 * (mesh.xStart + mesh.xFin)
                   if hCentreChange is None else hCentreChange)
        vCentre = (0.5 * (mesh.yStart + mesh.yFin)
                   if vCentreChange is None else vCentreChange)
        poolInputs.append(deepcopy([wfr,
                                    hResChange, vResChange,
                                    hRangeChange, vRangeChange,
                                    hCentre, vCentre]))
    # resizeWfr is a module-level function, so it can be pickled for the
    # worker processes; the context manager releases the pool on errors
    with mp.Pool(numCores) as pool:
        wfrListOut = pool.starmap(resizeWfr, poolInputs)
    if len(wfrListOut) == 1:
        return wfrListOut[0]
    return wfrListOut