|
18 | 18 | import functools |
19 | 19 | from typing import Optional, Sequence |
20 | 20 |
|
21 | | -import chex |
22 | 21 | import jax |
23 | 22 | import jax.numpy as jnp |
24 | 23 | import jax.scipy.stats.norm as multivariate_normal |
@@ -57,8 +56,8 @@ def canonicalize_key(key_or_seed: jax.Array | int) -> jax.Array: |
57 | 56 | warn_deprecated_function, replacement='optax.tree.cast' |
58 | 57 | ) |
def cast_tree(
    tree: base.ArrayTree, dtype: Optional[jax.typing.DTypeLike]
) -> base.ArrayTree:
  """Deprecated alias: casts every leaf of ``tree`` to ``dtype``.

  Thin wrapper kept for backward compatibility; delegates directly to
  ``optax.tree.cast`` (the decorator above emits a deprecation warning
  pointing callers at the replacement).

  Args:
    tree: arbitrary pytree of arrays to cast.
    dtype: target dtype; semantics of ``None`` are those of
      ``optax.tree.cast`` — presumably a no-op cast, TODO confirm.

  Returns:
    The tree with leaves cast via ``optax.tree.cast``.
  """
  return optax.tree.cast(tree, dtype)
63 | 62 |
|
64 | 63 |
|
@@ -171,29 +170,29 @@ def multi_normal( |
171 | 170 |
|
@jax.custom_vjp
def _scale_gradient(
    inputs: base.ArrayTree, scale: jax.typing.ArrayLike) -> base.ArrayTree:
  """Identity on the forward pass; ``scale`` only feeds the custom VJP.

  The primal computation returns ``inputs`` untouched. ``scale`` exists
  solely so the backward rule (``_scale_gradient_bwd``, registered via
  ``defvjp`` below) can multiply incoming cotangents by it.
  """
  # Unused in the primal; consumed by the backward rule.
  del scale
  return inputs
178 | 177 |
|
179 | 178 |
|
def _scale_gradient_fwd(
    inputs: base.ArrayTree, scale: jax.typing.ArrayLike
) -> tuple[base.ArrayTree, jax.typing.ArrayLike]:
  """Forward rule: run the primal and stash ``scale`` as the VJP residual."""
  primal_out = _scale_gradient(inputs, scale)
  # The residual (second element) is handed to _scale_gradient_bwd.
  return primal_out, scale
184 | 183 |
|
185 | 184 |
|
def _scale_gradient_bwd(
    scale: jax.typing.ArrayLike, g: base.ArrayTree
) -> tuple[base.ArrayTree, None]:
  """Backward rule: multiply every cotangent leaf by the stored ``scale``."""
  def _scaled(leaf):
    return leaf * scale
  # Trailing None: no gradient flows to the `scale` argument itself.
  return jax.tree.map(_scaled, g), None
190 | 189 |
|
191 | 190 |
|
# Register the custom forward/backward rules on the custom_vjp primitive.
_scale_gradient.defvjp(_scale_gradient_fwd, _scale_gradient_bwd)
193 | 192 |
|
194 | 193 |
|
195 | 194 | def scale_gradient( |
196 | | - inputs: chex.ArrayTree, scale: jax.typing.ArrayLike) -> chex.ArrayTree: |
| 195 | + inputs: base.ArrayTree, scale: jax.typing.ArrayLike) -> base.ArrayTree: |
197 | 196 | """Scales gradients for the backwards pass. |
198 | 197 |
|
199 | 198 | Args: |
|
0 commit comments