Source code for atomcloud.functions.jax_funcs.spence

def create_spence(npt):
    """Build a Spence function on top of a NumPy-compatible array module.

    The returned callable is a port of the Cephes ``spence`` routine for real
    arguments: ``spence(1) == 0``, ``spence(0) == pi**2 / 6`` and negative
    arguments yield NaN. Only array-module primitives (``where``, ``log``,
    ``polyval``) are used, so the same code runs with ``numpy`` or
    ``jax.numpy``.

    Args:
        npt: Array namespace to compute with, e.g. ``numpy`` or ``jax.numpy``.

    Returns:
        A function ``spence(x)`` accepting a scalar or array of real values
        and returning an array of the same shape. Boolean and integer input is
        promoted to the backend's default float type; floating input keeps its
        dtype.
    """
    pi_sq_over_6 = npt.pi**2 / 6.0

    def _spence_poly(w):
        # Rational approximation -w * A(w) / B(w) from Cephes; callers only
        # feed it w in [-0.5, 0.5]. Coefficients are highest degree first.
        A = npt.array(
            [
                4.65128586073990045278e-5,
                7.31589045238094711071e-3,
                1.33847639578309018650e-1,
                8.79691311754530315341e-1,
                2.71149851196553469920e0,
                4.25697156008121755724e0,
                3.29771340985225106936e0,
                1.00000000000000000126e0,
            ]
        )
        B = npt.array(
            [
                6.90990488912553276999e-4,
                2.54043763932544379113e-2,
                2.82974860602568089943e-1,
                1.41172597751831069617e0,
                3.63800533345137075418e0,
                5.03278880143316990390e0,
                3.54771340985225096217e0,
                9.99999999999999998740e-1,
            ]
        )
        return -w * npt.polyval(A, w) / npt.polyval(B, w)

    def _spence_calc(x):
        # Generic branch; the caller guarantees x > 0 (special values have
        # already been replaced by a placeholder).
        #
        # Every ``where`` below evaluates both alternatives for all elements,
        # so the discarded alternative is always fed a harmless stand-in
        # value. That avoids spurious NaN/inf (and NumPy floating-point
        # warnings) in lanes that are thrown away, and should keep them from
        # contaminating JAX gradients.
        above_two = x > 2.0
        # Map x > 2 onto (0, 0.5) via x -> 1/x.
        x = npt.where(above_two, 1.0 / npt.where(above_two, x, 1.0), x)

        upper = x > 1.5
        lower = x < 0.5
        # Argument of the rational approximation for each sub-interval.
        w = npt.where(
            upper,
            1.0 / npt.where(upper, x, 1.0) - 1.0,
            npt.where(lower, -x, x - 1.0),
        )
        y = _spence_poly(w)

        # Reflection for x < 0.5. log(1 - x) is only meaningful there, so the
        # other lanes use the placeholder 0.25.
        x_lower = npt.where(lower, x, 0.25)
        reflected = pi_sq_over_6 - npt.log(x_lower) * npt.log(1.0 - x_lower) - y
        y = npt.where(lower, reflected, y)

        # Inversion for original x > 1.5 (this includes the x > 2 lanes, which
        # have been through the reflection above as well).
        inverted = -0.5 * npt.log(x) ** 2 - y
        return npt.where(above_two | upper, inverted, y)

    def spence(x):
        x = npt.asarray(x)
        if x.dtype.kind in "biu":
            # An integer/bool result buffer cannot represent NaN or
            # pi**2 / 6, so compute in floating point instead.
            x = x.astype(float)

        negative = x < 0.0
        at_zero = x == 0.0
        at_one = x == 1.0

        # Run the generic branch on a placeholder (1.0) wherever a special
        # case applies, then overwrite those lanes with the exact values.
        regular_arg = npt.where(negative | at_zero | at_one, 1.0, x)
        y = _spence_calc(regular_arg)
        y = npt.where(at_one, 0.0, y)
        y = npt.where(at_zero, pi_sq_over_6, y)
        y = npt.where(negative, npt.nan, y)
        # Hand back the (possibly promoted) input dtype.
        return y.astype(x.dtype)

    return spence
def create_polylog2d(npt):
    """Build an order-2 polylogarithm on top of the Spence function.

    Args:
        npt: Array namespace (e.g. ``numpy`` or ``jax.numpy``) forwarded to
            ``create_spence``.

    Returns:
        A function ``polylog2d(z)`` that evaluates ``spence(1 - z)``.
    """
    spence_fn = create_spence(npt)

    def polylog2d(z):
        # Shift the argument into the convention used by the Spence function.
        shifted = 1 - z
        return spence_fn(shifted)

    return polylog2d