spyglass/faster raytrace.ipynb

26 KiB

In [5]:
# here's an attempt to speed up raytracing so it takes a reasonable amount of
# time to optimize systems
# mainly by leveraging numpy to run all the calculations in a tighter for
# loop than what Python can do

# first let's copy the system from before to test against

from rayoptics.environment import *

opm = OpticalModel()
sm  = opm['seq_model']
osp = opm['optical_spec']
pm = opm['parax_model']

osp['pupil'] = PupilSpec(osp, key=['object', 'pupil'], value=16)
osp['fov'] = FieldSpec(osp, key=['object', 'angle'], value=5, flds=[0., 0.707, 1.], is_relative=True)
osp['wvls'] = WvlSpec([('F', 0.5), (587.5618, 1.0), ('C', 0.5)], ref_wl=1)

opm.radius_mode = True

sm.gaps[0].thi=1e10

def calc_curvature(n, fl):
    return (n-1)*fl

def achromatic(f, v0, v1):
    return (v0-v1)*f/v0, -v1*f/(v0-v1)

def calc_curvatures(ns, fls):
    return tuple((n-1)*f for n, f in zip(ns, fls))

n_bk7 = 1.5168
n_lasf9 = 1.85025
n_f2 =  1.62005

v_bk7 = 64.17
v_lasf9 = 32.16
v_f2 = 36.43
# try for chaining a 3x telescope setup with a second 3x telescope setup
f0 = 150
f0_0, f0_1 = achromatic(f0, v_bk7, v_f2)

f1 = 60
f1_0, f1_1 = achromatic(f1, v_bk7, v_f2)

r0_0, r0_1, r1_0, r1_1 = calc_curvatures((n_bk7, n_f2, n_bk7, n_f2), (f0_0, f0_1, f1_0, f1_1))

sm.add_surface([r1_0, 4, 'N-BK7', 'Schott', 16])
sm.add_surface([1e9, 2, 'N-F2', 'Schott', 16])
sm.add_surface([-r1_1, 30])

opm.update_model()
In [146]:
# now we trace all the functions needed for raytracing until we get to the algorithm,
# having them transfer arrays of values to evaluate instead of a single point

import rayoptics.raytr.trace as raytr_trace
from rayoptics.optical.model_constants import Intfc, Gap, Tfrm, Indx, Zdir
from rayoptics.elem.profiles import Spherical
from rayoptics.elem.surface import Surface

from rayoptics.raytr.traceerror import TraceError, TraceMissedSurfaceError, TraceTIRError, TraceEvanescentRayError

def super_trace_grid(sm, fi, wl=None, num_rays=21,
                   append_if_none=True, **kwargs):
    """ fct is applied to the raw grid and returned as a grid  """
    osp = sm.opt_model.optical_spec
    wvls = osp.spectral_region
    wvl = sm.central_wavelength()
    wv_list = wvls.wavelengths if wl is None else [wvl]
    fld = osp.field_of_view.fields[fi]
    foc = osp.defocus.get_focus()

    # make sure this is imported
    rs_pkg, cr_pkg = raytr_trace.setup_pupil_coords(sm.opt_model,
                                              fld, wvl, foc)
    fld.chief_ray = cr_pkg
    fld.ref_sphere = rs_pkg

    grids = []
    grid_start = np.array([-1., -1.])
    grid_stop = np.array([1., 1.])
    grid_def = [grid_start, grid_stop, num_rays]
    results = None
    for wi, wvl in enumerate(wv_list):
        result = np.expand_dims(rly_trace_grid(sm.opt_model, grid_def, fld, wvl, foc,
                                **kwargs), axis=0)
        if results is None:
            results = result
        else:
            results = np.concatenate((results, result), axis=0)
    rc = wvls.render_colors
    return results, rc

# this generates an array of valid points from the grid
def rly_trace_grid(opt_model, grid_rng, fld, wvl, foc, **kwargs): # from trace_grid
    start = np.array(grid_rng[0])
    stop = grid_rng[1]
    num = grid_rng[2]
    step = np.array((stop - start)/(num - 1))
    grid = []
    
    valid_points = None
    
    for x in np.linspace(start[0], stop[0], num, dtype=np.float64):
        ys = np.linspace(start[1], stop[1], num, dtype=np.float64)
        valid_ys = ys[x**2 + ys**2 <= 1.0]
        points = np.zeros((valid_ys.shape[0], 2))
        points[:, 0] = x
        points[:, 1] = valid_ys
        if valid_points is None:
            valid_points = points
        else:
            valid_points = np.concatenate((valid_points, points))
    return rly_rly_trace(opt_model, valid_points, fld, wvl, **kwargs)

def fld_apply_vignetting(fld, pupils):
    ret = np.copy(pupils)
    if fld.vlx != 0.0:
        ret[pupils[:, 0] < 0.0, 0] *= (1. - fld.vlx)
    if fld.vux != 0.0:
        ret[pupils[:, 0] > 0.0, 0] *= (1. - fld.vux)
    if fld.vly != 0.0:
        ret[pupils[:, 1] < 0.0, 1] *= (1. - fld.vly)
    if fld.vuy != 0.0:
        ret[pupils[:, 1] > 0.0, 1] *= (1. - fld.vuy)
    return ret

def norm(v):
    return np.sqrt(np.sum(v*v, axis=1))
    
# some final data conditioning before the good part
def rly_rly_trace(opt_model, pupils, fld, wvl, **kwargs): # from trace_base
    vig_pupils = fld_apply_vignetting(fld, pupils)
    osp = opt_model.optical_spec
    fod = osp.parax_data.fod
    eprad = fod.enp_radius
    aim_pt = np.array([0., 0.])
    if hasattr(fld, 'aim_pt') and fld.aim_pt is not None:
        aim_pt = np.array(fld.aim_pt)
    pt1 = np.zeros((pupils.shape[0], 3))
    pt1[:,0:2] = eprad*vig_pupils+aim_pt
    pt1[:, 2] = fod.obj_dist+fod.enp_dist
    pt0 = osp.obj_coords(fld)
    dir0 = pt1 - pt0
    length = norm(dir0)
    dir0 = dir0/np.expand_dims(length, axis=1)
    return rly_rly_rly_trace(opt_model.seq_model, pt0, dir0, wvl, **kwargs)

def rly_rly_rly_trace(seq_model, pt0, dir0, wvl, eps=1.0e-12, **kwargs): # from raytr.trace and raytr.trace_raw
    path = [v for v in seq_model.path(wvl)]
    
    #kwargs['first_surf'] = kwargs.get('first_surf', 1)
    #kwargs['last_surf'] = kwargs.get('last_surf',
    #                                 seq_model.get_num_surfaces()-2)
    
    rays = np.zeros((dir0.shape[0], len(path), 10))
    # rays[nray][path][px, py, pz, dx, dy, dz, nx, ny, nz, dist]

    #first_surf = kwargs.get('first_surf', 0)
    #last_surf = kwargs.get('last_surf', None)
    first_surf = kwargs.get('first_surf', 1)
    last_surf = kwargs.get('last_surf', seq_model.get_num_surfaces()-2)
    
    # trace object surface
    obj = path[0]
    pt0 = np.expand_dims(pt0, axis=0)

    srf_obj = obj[Intfc]
    dst_b4, pt_obj = itfc_intersect(srf_obj, pt0, dir0, z_dir=obj[Zdir])

    before = obj
    before_pt = pt_obj
    before_dir = dir0
    before_normal = itfc_normal(srf_obj, before_pt)
    tfrm_from_before = before[Tfrm]
    z_dir_before = before[Zdir]

    # loop remaining surfaces in path
    for surf, after in enumerate(path[1:]):
        try:
            #np.tensordot(t, pts, (0, 1)).T seems to work for mass matrix multiplication
            rt, t = tfrm_from_before
            #b4_pt, b4_dir = rt.dot(before_pt - t), rt.dot(before_dir)
            b4_pt = np.tensordot(rt, before_pt - t, (0, 1)).T
            b4_dir = np.tensordot(rt, before_dir, (0, 1)).T

            #pp_dst = -b4_pt.dot(b4_dir)
            pp_dst = -np.expand_dims(np.sum(b4_pt*b4_dir, axis=1), axis=1)
            pp_pt_before = b4_pt + pp_dst*b4_dir

            ifc = after[Intfc]
            z_dir_after = after[Zdir]

            # intersect ray with profile
            pp_dst_intrsct, inc_pt = itfc_intersect(ifc, pp_pt_before, b4_dir,
                                                   eps=eps, z_dir=z_dir_before)
            dst_b4 = pp_dst[:, 0] + pp_dst_intrsct
            rays[:, surf, 0:3] = before_pt
            rays[:, surf, 3:6] = before_dir
            rays[:, surf, 6:9] = before_normal
            rays[:, surf, 9] = dst_b4

            normal = itfc_normal(ifc, inc_pt)


            '''
            # if the interface has a phase element, process that first
            if hasattr(ifc, 'phase_element'):
                doe_dir, phs = phase(ifc, inc_pt, b4_dir, normal, z_dir_before,
                                     wvl, before[Indx], after[Indx])
                # the output of the phase element becomes the input for the
                #  refraction/reflection calculation
                b4_dir = doe_dir
                op_delta += phs
            '''

            # refract or reflect ray at interface
            if ifc.interact_mode == 'reflect':
                after_dir = reflect(b4_dir, normal)
            elif ifc.interact_mode == 'transmit':
                after_dir = bend(b4_dir, normal, before[Indx], after[Indx])
            elif ifc.interact_mode == 'dummy':
                after_dir = b4_dir
            else:  # no action, input becomes output
                after_dir = b4_dir

            # Per `Hopkins, 1981 <https://dx.doi.org/10.1080/713820605>`_, the
            #  propagation direction is given by the direction cosines of the
            #  ray and therefore doesn't require the use of a negated
            #  refractive index following a reflection. Thus we use the
            #  (positive) refractive indices from the seq_model.rndx array.

            before_pt = inc_pt
            before_normal = normal
            before_dir = after_dir
            z_dir_before = z_dir_after
            before = after
            tfrm_from_before = before[Tfrm]

        except TraceMissedSurfaceError as ray_miss:
            ray_miss.surf = surf+1
            ray_miss.ifc = ifc
            ray_miss.prev_tfrm = before[Tfrm]
            ray_miss.ray_pkg = rays, wvl
            raise ray_miss

        except TraceTIRError as ray_tir:
            ray_tir.surf = surf+1
            ray_tir.ifc = ifc
            ray_tir.int_pt = inc_pt
            ray_tir.ray_pkg = rays, wvl
            raise ray_tir

        except TraceEvanescentRayError as ray_evn:
            ray_evn.surf = surf+1
            ray_evn.ifc = ifc
            ray_evn.int_pt = inc_pt
            ray_evn.ray_pkg = rays, wvl
            raise ray_evn

    # lifted from the loop since it'll no longer hit the
    # StopIteration exception

    if len(path) > 1:
        rays[:, surf + 1, 0:3] = inc_pt
        rays[:, surf + 1, 3:6] = after_dir
        rays[:, surf + 1, 6:9] = normal
        rays[:, surf + 1, 9] = 0
    return rays
    
def itfc_intersect(ifc, p, d, eps=1e-12, z_dir=1):
    if isinstance(ifc, Surface):
        if isinstance(ifc.profile, Spherical):
            # copied from elem.profiles
            ax2 = ifc.profile.cv
            cx2 = ifc.profile.cv * np.sum(p*p, axis=1) - 2*p[:,2]
            b = ifc.profile.cv * np.sum(d*p, axis=1) - d[:,2]
            
            discr = b*b-ax2*cx2
            # Use z_dir to pick correct root
            if np.any(discr < 0):
                raise TraceMissedSurfaceError

            s = cx2/(z_dir*np.sqrt(discr) - b)

            p1 = p + np.expand_dims(s, axis=1)*d
            return s, p1
        else:
            raise RuntimeError("intersection not implemented for profile {}".format(ifc.profile))
    else:
        raise RuntimeError("intersection not implemented for {}".format(ifc))

def itfc_normal(ifc, pts):
    if isinstance(ifc, Surface):
        if isinstance(ifc.profile, Spherical):
            # copied from elem.profiles
            return np.stack([-ifc.profile.cv*pts[:,0], -ifc.profile.cv*pts[:,1], 1.0-ifc.profile.cv*pts[:,2]],
                           axis=-1)
        else:
            raise RuntimeError("intersection not implemented for profile {}".format(ifc.profile))
    else:
        raise RuntimeError("intersection not implemented for {}".format(ifc))
        
#copied from raytrace.reflect
def reflect(d_in, normals):
    normal_len = norm(normals)
    cosI = np.sum(d_in * normals, axis=1)/normal_len
    d_out = d_in - 2.0*cosI*normals
    return d_out

#copied from raytrace.bend
def bend(d_in, normal, n_in, n_out):
    try:
        normal_len = norm(normal)
        cosI = np.sum(d_in * normal, axis=1)/normal_len
        sinI_sqr = 1.0 - cosI*cosI
        n_cosIp = np.copysign(np.sqrt(n_out*n_out - n_in*n_in*sinI_sqr), cosI)
        alpha = np.expand_dims(n_cosIp - n_in*cosI, axis=1)
        d_out = (n_in*d_in + alpha*normal)/n_out
        return d_out
    except ValueError:
        raise TraceTIRError(d_in, normal, n_in, n_out)
In [147]:
super_rays, _ = super_trace_grid(sm, 0)
#print(super_rays[:,:,-1])

def ray_err(sm, rays, fi):
    osp = sm.opt_model.optical_spec
    fld = osp.field_of_view.fields[fi]
    foc = osp.defocus.get_focus()
    
    image_pt = fld.ref_sphere[0]
    dist = np.expand_dims(foc/rays[:,:,-1,5], axis=-1)
    defocused_pts = rays[:,:,-1,0:3] + dist*rays[:,:,-1,3:6]
    t_abr = defocused_pts - image_pt
    return np.sqrt(np.sum(t_abr*t_abr, axis=2))

def dump_dist(p, wi, ray_pkg, fld, wvl, foc):
    if ray_pkg is not None:
        image_pt = fld.ref_sphere[0]
        ray = ray_pkg[mc.ray]
        dist = foc / ray[-1][mc.d][2]
        defocused_pt = ray[-1][mc.p] + dist*ray[-1][mc.d]
        t_abr = defocused_pt - image_pt
        return np.sqrt(np.sum(t_abr*t_abr))
    
def spot_rms(sm, fld_idx=0, num_rays=21):
    return np.sqrt(np.mean(np.square(sm.trace_grid(dump_dist, fld_idx, form='list', num_rays=21, append_if_none=False)[0]), axis=1))

def spot_rms2(sm, rays=None, fld_idx=0, num_rays=21):
    if rays is None:
        rays, _ = super_trace_grid(sm, fld_idx, num_rays=21)
    return np.sqrt(np.mean(np.square(ray_err(sm, rays, fld_idx)), axis=1))

spot_rms(sm), spot_rms2(sm)
Out[147]:
(array([0.44231141, 0.4014741 , 0.38252552]),
 array([0.44778724, 0.40678566, 0.38775669]))
In [163]:
# uhh why are they different?
# let's check the rays themselves to see what they look like

def dump_rays(p, wi, ray_pkg, fld, wvl, foc):
    if ray_pkg is not None:
        ray = ray_pkg[mc.ray]
        return [ray[-1][mc.p], ray[-1][mc.d]]

sm.trace_grid(dump_rays, 0, form='list', append_if_none=False)[0][0][0:10], super_rays[0,0:10,-1]
Out[163]:
(array([[[ 8.43971916e-01,  3.75098629e-01,  0.00000000e+00],
         [ 2.36775522e-01,  1.05233565e-01,  9.65848461e-01]],
 
        [[ 7.60036422e-01,  2.53345474e-01,  0.00000000e+00],
         [ 2.33696527e-01,  7.78988423e-02,  9.69184040e-01]],
 
        [[ 7.02253533e-01,  1.56056341e-01,  0.00000000e+00],
         [ 2.31559445e-01,  5.14576544e-02,  9.71458869e-01]],
 
        [[ 6.68405774e-01,  7.42673082e-02,  0.00000000e+00],
         [ 2.30301018e-01,  2.55890020e-02,  9.72782938e-01]],
 
        [[ 6.57255825e-01,  1.01347299e-16,  0.00000000e+00],
         [ 2.29885412e-01,  3.54477885e-17,  9.73217703e-01]],
 
        [[ 6.68405774e-01, -7.42673082e-02,  0.00000000e+00],
         [ 2.30301018e-01, -2.55890020e-02,  9.72782938e-01]],
 
        [[ 7.02253533e-01, -1.56056341e-01,  0.00000000e+00],
         [ 2.31559445e-01, -5.14576544e-02,  9.71458869e-01]],
 
        [[ 7.60036422e-01, -2.53345474e-01,  0.00000000e+00],
         [ 2.33696527e-01, -7.78988423e-02,  9.69184040e-01]],
 
        [[ 8.43971916e-01, -3.75098629e-01,  0.00000000e+00],
         [ 2.36775522e-01, -1.05233565e-01,  9.65848461e-01]],
 
        [[ 6.65190535e-01,  4.15744084e-01,  0.00000000e+00],
         [ 2.07346720e-01,  1.29591700e-01,  9.69645981e-01]]]),
 array([[ 0.97899133,  0.        ,  0.        ,  0.26458622,  0.        ,
          0.96436203, -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.84397192,  0.37509863,  0.        ,  0.23677552,  0.10523357,
          0.96584846, -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.76003642,  0.25334547,  0.        ,  0.23369653,  0.07789884,
          0.96918404, -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.70225353,  0.15605634,  0.        ,  0.23155944,  0.05145765,
          0.97145887, -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.66840577,  0.07426731,  0.        ,  0.23030102,  0.025589  ,
          0.97278294, -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.65725583,  0.        ,  0.        ,  0.22988541,  0.        ,
          0.9732177 , -0.        , -0.        ,  1.        ,  0.        ],
        [ 0.66840577, -0.07426731,  0.        ,  0.23030102, -0.025589  ,
          0.97278294, -0.        ,  0.        ,  1.        ,  0.        ],
        [ 0.70225353, -0.15605634,  0.        ,  0.23155944, -0.05145765,
          0.97145887, -0.        ,  0.        ,  1.        ,  0.        ],
        [ 0.76003642, -0.25334547,  0.        ,  0.23369653, -0.07789884,
          0.96918404, -0.        ,  0.        ,  1.        ,  0.        ],
        [ 0.84397192, -0.37509863,  0.        ,  0.23677552, -0.10523357,
          0.96584846, -0.        ,  0.        ,  1.        ,  0.        ]]))
In [164]:
# they look really similar...
# let's make sure they're the same size
ref_rays, _ = sm.trace_grid(dump_rays, 0, form='list', append_if_none=False)
len(ref_rays[0]), super_rays[0,:,-1].shape
Out[164]:
(311, (313, 10))
In [165]:
# so the original has five more points than super_rays...
# but which ones?

def dump_start(p, wi, ray_pkg, fld, wvl, foc):
    if ray_pkg is not None:
        image_pt = fld.ref_sphere[0]
        ray = ray_pkg[mc.ray]
        #v = ray[-1][mc.d][0:2] / ray[-1][mc.d][2]
        return [ray[0][mc.p], ray[0][mc.d]]
    
sm.trace_grid(dump_start, 0, form='list', append_if_none=False)[0][0]
Out[165]:
array([[[-0.00000000e+00, -0.00000000e+00,  0.00000000e+00],
        [-7.20000000e-10, -3.20000000e-10,  1.00000000e+00]],

       [[-0.00000000e+00, -0.00000000e+00,  0.00000000e+00],
        [-7.20000000e-10, -2.40000000e-10,  1.00000000e+00]],

       [[-0.00000000e+00, -0.00000000e+00,  0.00000000e+00],
        [-7.20000000e-10, -1.60000000e-10,  1.00000000e+00]],

       ...,

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 7.20000000e-10,  2.40000000e-10,  1.00000000e+00]],

       [[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 7.20000000e-10,  3.20000000e-10,  1.00000000e+00]],

       [[ 0.00000000e+00, -0.00000000e+00,  0.00000000e+00],
        [ 8.00000000e-10, -1.11022302e-25,  1.00000000e+00]]])
In [166]:
super_rays[0,:,0,0:6]
Out[166]:
array([[-0.0e+00,  0.0e+00,  0.0e+00, -8.0e-10,  0.0e+00,  1.0e+00],
       [-0.0e+00, -0.0e+00,  0.0e+00, -7.2e-10, -3.2e-10,  1.0e+00],
       [-0.0e+00, -0.0e+00,  0.0e+00, -7.2e-10, -2.4e-10,  1.0e+00],
       ...,
       [ 0.0e+00,  0.0e+00,  0.0e+00,  7.2e-10,  2.4e-10,  1.0e+00],
       [ 0.0e+00,  0.0e+00,  0.0e+00,  7.2e-10,  3.2e-10,  1.0e+00],
       [ 0.0e+00,  0.0e+00,  0.0e+00,  8.0e-10,  0.0e+00,  1.0e+00]])
In [167]:
# look at that, the last point's different
# it looks like it'd be pointed closer to the right edge of the screen than the other ones

# let's reduce the objective distance and see if the discrepancy disappears

sm.gaps[0].thi=1e9
opm.update_model()

spot_rms(sm), spot_rms2(sm)
Out[167]:
(array([0.00557354, 0.0097539 , 0.01172143]),
 array([0.00558912, 0.00978306, 0.01175699]))
In [168]:
# ugh let's see if it persists with different grid sizes

sm.gaps[0].thi=1e10
opm.update_model()

spot_rms(sm, num_rays=41), spot_rms2(sm, num_rays=41), spot_rms(sm, num_rays=61), spot_rms2(sm, num_rays=61)
Out[168]:
(array([0.44231141, 0.4014741 , 0.38252552]),
 array([0.44778724, 0.40678566, 0.38775669]),
 array([0.44231141, 0.4014741 , 0.38252552]),
 array([0.44778724, 0.40678566, 0.38775669]))
In [172]:
#  the difference is fairly small, let's see if the speedup is worth trading some accuracy

import time
ref_start = time.perf_counter_ns()
for _ in range(100):
    sm.trace_grid(dump_rays, 0, form='list', append_if_none=False)
ref_end = time.perf_counter_ns()

fast_start = time.perf_counter_ns()
for _ in range(100):
    super_trace_grid(sm, 0)
fast_end = time.perf_counter_ns()

ref_elapsed = (ref_end - ref_start)/10.
fast_elapsed = (fast_end - fast_start)/10.

ref_elapsed, fast_elapsed, ref_elapsed/fast_elapsed
Out[172]:
(2195380700.2, 65261865.4, 33.63956403550794)
In [ ]:
# ~30x speedup is probably worth a <1% error, let's just go with it