
Commit af59104

vroulet authored and OptaxDev committed
Add an internal definition of ArrayTree and use it instead of chex.ArrayTree.
PiperOrigin-RevId: 831123754
1 parent daecb91 commit af59104

20 files changed (+75, -89 lines)

docs/conf.py (1 addition, 1 deletion)

@@ -125,7 +125,7 @@ def _recursive_add_annotations_import():
     'base.Updates': 'optax.Updates',
     'base.OptState': 'optax.OptState',
     'base.PyTree': 'optax.PyTree',
-    'chex.ArrayTree': 'chex.ArrayTree',
+    'base.ArrayTree': 'optax.ArrayTree',
     'jax.typing.ArrayLike': 'jax.typing.ArrayLike'
 }

optax/_src/alias_test.py (3 additions, 4 deletions)

@@ -20,7 +20,6 @@

 from absl.testing import absltest
 from absl.testing import parameterized
-import chex
 import jax
 from jax import flatten_util
 import jax.numpy as jnp
@@ -433,11 +432,11 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):

 def _run_opt(
     opt: base.GradientTransformationExtraArgs,
-    fun: Callable[[chex.ArrayTree], jnp.ndarray],
-    init_params: chex.ArrayTree,
+    fun: Callable[[base.ArrayTree], jnp.ndarray],
+    init_params: base.ArrayTree,
     maxiter: int = 500,
     tol: float = 1e-3,
-) -> tuple[chex.ArrayTree, base.OptState]:
+) -> tuple[base.ArrayTree, base.OptState]:
   """Run LBFGS solver by iterative calls to grad transform and apply_updates."""
   value_and_grad_fun = jax.value_and_grad(fun)
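The helper's docstring describes the usual pattern for driving L-BFGS: alternate calls to the gradient transformation's update with optax.apply_updates. A minimal sketch of such a loop (illustrative only, not the test's actual body; the run_lbfgs name and the quadratic objective are assumptions):

import jax
import jax.numpy as jnp
import optax

def run_lbfgs(fun, init_params, maxiter=500):
  # L-BFGS in optax uses the extra-args interface: value, grad and value_fn
  # are forwarded to the line search inside opt.update.
  opt = optax.lbfgs()
  value_and_grad_fun = jax.value_and_grad(fun)
  params, state = init_params, opt.init(init_params)
  for _ in range(maxiter):
    value, grad = value_and_grad_fun(params)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=fun)
    params = optax.apply_updates(params, updates)
  return params, state

params, _ = run_lbfgs(lambda p: jnp.sum(p ** 2), jnp.ones(4), maxiter=10)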

optax/_src/base.py (6 additions, 5 deletions)

@@ -15,10 +15,9 @@
 """Base interfaces and datatypes."""

 from collections.abc import Callable
-from typing import (Any, NamedTuple, Optional, Protocol, Sequence, Union,
-                    runtime_checkable)
+from typing import (Any, Iterable, Mapping, NamedTuple, Optional, Protocol,
+                    Sequence, Union, runtime_checkable)

-import chex
 import jax
 import jax.numpy as jnp

@@ -30,9 +29,11 @@
 PyTree = Any
 Shape = Sequence[int]
 PRNGKey = jax.Array
+ArrayTree = Union[
+    jax.typing.ArrayLike, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]

-OptState = chex.ArrayTree  # States are arbitrary nests of `jnp.ndarrays`.
-Params = chex.ArrayTree  # Parameters are arbitrary nests of `jnp.ndarrays`.
+OptState = ArrayTree  # States are arbitrary nests of `jnp.ndarrays`.
+Params = ArrayTree  # Parameters are arbitrary nests of `jnp.ndarrays`.
 Updates = Params  # Gradient updates are of the same type as parameters.

 Schedule = Callable[[jax.typing.ArrayLike], jax.typing.ArrayLike]
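The new alias is a recursive union: an ArrayTree is either an array-like leaf or an iterable/mapping of further ArrayTrees, i.e. exactly the nests that params, updates and optimizer states are made of. The docs/conf.py change above suggests it is also surfaced as optax.ArrayTree. A small illustrative sketch of what the annotation admits (the function below is not part of optax):

import jax
import jax.numpy as jnp
from optax._src import base

def tree_l2_norm(tree: base.ArrayTree) -> jax.Array:
  # Any nest of array-likes type-checks: a bare array, a list of arrays,
  # or a nested dict mixing arrays and Python scalars.
  leaves = jax.tree.leaves(tree)
  return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))

print(tree_l2_norm(jnp.ones(3)))
print(tree_l2_norm({'w': jnp.ones((2, 2)), 'b': [jnp.zeros(2), 1.0]}))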

optax/_src/factorized.py (3 additions, 4 deletions)

@@ -18,7 +18,6 @@
 import dataclasses
 from typing import NamedTuple, Optional

-import chex
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -81,9 +80,9 @@ class FactoredState(NamedTuple):
   """Overall state of the gradient transformation."""

   count: jax.typing.ArrayLike  # number of update steps.
-  v_row: chex.ArrayTree  # Tree of factored params.
-  v_col: chex.ArrayTree  # Tree of factored params.
-  v: chex.ArrayTree  # Tree for params where factoring is skipped.
+  v_row: base.ArrayTree  # Tree of factored params.
+  v_col: base.ArrayTree  # Tree of factored params.
+  v: base.ArrayTree  # Tree for params where factoring is skipped.


 def scale_by_factored_rms(
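FactoredState holds the row- and column-factored second-moment accumulators that scale_by_factored_rms (whose definition starts here) maintains, Adafactor-style, plus a full accumulator for leaves where factoring is skipped. A hedged usage sketch, assuming the transform is exposed as optax.scale_by_factored_rms:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4, 8))}
grads = {'w': jnp.full((4, 8), 0.5)}

# The transform keeps factored statistics for large matrices and falls back
# to full statistics otherwise, matching the v_row / v_col / v fields above.
tx = optax.scale_by_factored_rms()
state = tx.init(params)
updates, state = tx.update(grads, state, params)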

optax/_src/linear_algebra.py (3 additions, 4 deletions)

@@ -19,7 +19,6 @@
 from typing import Optional, Union
 import warnings

-import chex
 import jax
 from jax import lax
 import jax.numpy as jnp
@@ -61,14 +60,14 @@ def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):

 def power_iteration(
     matrix: Union[
-        jax.typing.ArrayLike, Callable[[chex.ArrayTree], chex.ArrayTree]],
+        jax.typing.ArrayLike, Callable[[base.ArrayTree], base.ArrayTree]],
     *,
-    v0: Optional[chex.ArrayTree] = None,
+    v0: Optional[base.ArrayTree] = None,
     num_iters: jax.typing.ArrayLike = 100,
     error_tolerance: jax.typing.ArrayLike = 1e-6,
     precision: lax.Precision = lax.Precision.HIGHEST,
     key: Optional[base.PRNGKey] = None,
-) -> tuple[jax.typing.ArrayLike, chex.ArrayTree]:
+) -> tuple[jax.typing.ArrayLike, base.ArrayTree]:
   r"""Power iteration algorithm.

   This algorithm computes the dominant eigenvalue (i.e. the spectral radius) and
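Per the signature above, power_iteration takes either a dense matrix or a matrix-vector-product callable over trees and returns the dominant eigenvalue together with its eigenvector. A hedged usage sketch, assuming the function is re-exported at the top level as optax.power_iteration:

import jax
import jax.numpy as jnp
import optax

# Dominant eigenvalue/eigenvector of a small symmetric matrix.
a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
eigval, eigvec = optax.power_iteration(a, key=jax.random.key(0))

# The operator can also be passed implicitly as a matrix-vector product; an
# explicit v0 is given here since the shape cannot be read off the callable
# (whether that is strictly required is an assumption).
mvp = lambda v: a @ v
eigval_mvp, _ = optax.power_iteration(mvp, v0=jnp.ones(2), key=jax.random.key(1))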

optax/_src/transform.py (6 additions, 7 deletions)

@@ -17,7 +17,6 @@
 import functools
 from typing import NamedTuple, Optional

-import chex
 import jax
 from jax import nn
 import jax.numpy as jnp
@@ -1504,15 +1503,15 @@ class ScaleByLBFGSState(NamedTuple):
   count: jax.typing.ArrayLike
   params: base.Params
   updates: base.Params
-  diff_params_memory: chex.ArrayTree
-  diff_updates_memory: chex.ArrayTree
+  diff_params_memory: base.ArrayTree
+  diff_updates_memory: base.ArrayTree
   weights_memory: jax.typing.ArrayLike


 def _precondition_by_lbfgs(
     updates: base.Updates,
-    diff_params_memory: chex.ArrayTree,
-    diff_updates_memory: chex.ArrayTree,
+    diff_params_memory: base.ArrayTree,
+    diff_updates_memory: base.ArrayTree,
     weights_memory: jax.typing.ArrayLike,
     identity_scale: jax.typing.ArrayLike,  # float
     memory_idx: jax.typing.ArrayLike,  # int
@@ -1822,8 +1821,8 @@ def update_fn(
     warn_deprecated_function, replacement='optax.tree.cast'
 )
 def cast_tree(
-    tree: chex.ArrayTree, dtype: Optional[jax.typing.DTypeLike]
-) -> chex.ArrayTree:
+    tree: base.ArrayTree, dtype: Optional[jax.typing.DTypeLike]
+) -> base.ArrayTree:
   return optax.tree.cast(tree, dtype)
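The deprecated cast_tree wrapper (here and in utils.py below) simply forwards to optax.tree.cast, which the deprecation message names as the replacement. A short sketch of the replacement call:

import jax.numpy as jnp
import optax

# Cast every leaf of a pytree to a given dtype.
tree = {'w': jnp.ones(3, dtype=jnp.float32), 'b': jnp.zeros(2)}
tree_bf16 = optax.tree.cast(tree, jnp.bfloat16)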

optax/_src/utils.py (8 additions, 9 deletions)

@@ -18,7 +18,6 @@
 import functools
 from typing import Optional, Sequence

-import chex
 import jax
 import jax.numpy as jnp
 import jax.scipy.stats.norm as multivariate_normal
@@ -57,8 +56,8 @@ def canonicalize_key(key_or_seed: jax.Array | int) -> jax.Array:
     warn_deprecated_function, replacement='optax.tree.cast'
 )
 def cast_tree(
-    tree: chex.ArrayTree, dtype: Optional[jax.typing.DTypeLike]
-) -> chex.ArrayTree:
+    tree: base.ArrayTree, dtype: Optional[jax.typing.DTypeLike]
+) -> base.ArrayTree:
   return optax.tree.cast(tree, dtype)


@@ -171,29 +170,29 @@ def multi_normal(

 @jax.custom_vjp
 def _scale_gradient(
-    inputs: chex.ArrayTree, scale: jax.typing.ArrayLike) -> chex.ArrayTree:
+    inputs: base.ArrayTree, scale: jax.typing.ArrayLike) -> base.ArrayTree:
   """Internal gradient scaling implementation."""
   del scale  # Only used for the backward pass defined in _scale_gradient_bwd.
   return inputs


 def _scale_gradient_fwd(
-    inputs: chex.ArrayTree, scale: jax.typing.ArrayLike
-) -> tuple[chex.ArrayTree, jax.typing.ArrayLike]:
+    inputs: base.ArrayTree, scale: jax.typing.ArrayLike
+) -> tuple[base.ArrayTree, jax.typing.ArrayLike]:
   return _scale_gradient(inputs, scale), scale


 def _scale_gradient_bwd(
-    scale: jax.typing.ArrayLike, g: chex.ArrayTree
-) -> tuple[chex.ArrayTree, None]:
+    scale: jax.typing.ArrayLike, g: base.ArrayTree
+) -> tuple[base.ArrayTree, None]:
   return (jax.tree.map(lambda g_: g_ * scale, g), None)


 _scale_gradient.defvjp(_scale_gradient_fwd, _scale_gradient_bwd)


 def scale_gradient(
-    inputs: chex.ArrayTree, scale: jax.typing.ArrayLike) -> chex.ArrayTree:
+    inputs: base.ArrayTree, scale: jax.typing.ArrayLike) -> base.ArrayTree:
   """Scales gradients for the backwards pass.

   Args:
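The custom VJP above leaves the forward value untouched and multiplies the incoming cotangent by scale on the backward pass. A small sketch of the observable effect through jax.grad, assuming the public entry point is optax.scale_gradient:

import jax
import jax.numpy as jnp
import optax

def loss(x):
  # Forward value is unchanged; only the gradient flowing back is scaled.
  x = optax.scale_gradient(x, 0.1)
  return jnp.sum(x ** 2)

print(jax.grad(loss)(jnp.ones(3)))  # roughly [0.2, 0.2, 0.2] = 0.1 * 2 * x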

optax/contrib/_dog.py (2 additions, 3 deletions)

@@ -25,7 +25,6 @@
 from collections.abc import Callable
 from typing import Any, NamedTuple, Optional, Union, Literal

-import chex
 import jax
 import jax.numpy as jnp
 from optax._src import base
@@ -38,7 +37,7 @@ class DoGState(NamedTuple):
   """State for DoG optimizer."""

   is_init_step: jax.Array  # bool
-  init_params: chex.ArrayTree
+  init_params: base.ArrayTree
   max_dist: jax.Array
   sum_sq_norm_grads: jax.Array

@@ -220,7 +219,7 @@ def dog(
 class DoWGState(NamedTuple):
   """State for DoWG optimizer."""

-  init_params: chex.ArrayTree
+  init_params: base.ArrayTree
   weighted_sq_norm_grads: jax.Array
   estim_sq_dist: jax.Array
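Both states keep a copy of the initial parameters so the distance-over-gradients step size can be computed at update time. A hedged usage sketch, assuming DoG is exposed as optax.contrib.dog and accepts a learning_rate argument (both assumptions about the public surface, not shown in this diff):

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
grads = {'w': jnp.full(3, 0.5)}

# DoG needs the current params at update time so it can measure the distance
# from the init_params stored in its state.
opt = optax.contrib.dog(learning_rate=1.0)
state = opt.init(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)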

optax/experimental/_aggregating.py (2 additions, 3 deletions)

@@ -17,7 +17,6 @@
 import math
 from typing import Any, NamedTuple, Protocol, Sequence

-import chex
 import jax
 import jax.numpy as jnp
 from optax import tree
@@ -26,8 +25,8 @@
 from optax.transforms import _combining


-PerElementUpdates = chex.ArrayTree
-AggregatedUpdates = chex.ArrayTree
+PerElementUpdates = base.ArrayTree
+AggregatedUpdates = base.ArrayTree
 MaybeAxis = int | Sequence[int] | None

optax/experimental/_microbatching.py (5 additions, 5 deletions)

@@ -20,9 +20,9 @@
 import functools
 from typing import Any, Callable, Sequence, TypeAlias

-import chex
 import jax
 import jax.numpy as jnp
+from optax._src import base


 AccumulatorTree: TypeAlias = Any
@@ -53,10 +53,10 @@ class Accumulator:
     per-microbatch values into a single value. Used by `gvmap`.
   """

-  init: Callable[[chex.ArrayTree], chex.ArrayTree]
-  update: Callable[[chex.ArrayTree, chex.ArrayTree, int], chex.ArrayTree]
-  finalize: Callable[[chex.ArrayTree], chex.ArrayTree]
-  aggregate: Callable[[chex.ArrayTree], chex.ArrayTree]
+  init: Callable[[base.ArrayTree], base.ArrayTree]
+  update: Callable[[base.ArrayTree, base.ArrayTree, int], base.ArrayTree]
+  finalize: Callable[[base.ArrayTree], base.ArrayTree]
+  aggregate: Callable[[base.ArrayTree], base.ArrayTree]


 def _with_floating_check(fn: Callable[..., Any]) -> Callable[..., Any]:
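Accumulator bundles four tree-to-tree callables that gvmap uses to fold per-microbatch results into a single value. A hedged sketch of what such callables could look like for a running mean over microbatches (plain functions only; the names and the way an Accumulator is actually constructed are assumptions, not shown in this diff):

import jax
import jax.numpy as jnp

def init(tree):
  # init: start from zeros shaped like the per-microbatch updates.
  return jax.tree.map(jnp.zeros_like, tree)

def update(acc, new, i):
  # update: running mean after seeing microbatch i (0-indexed).
  return jax.tree.map(lambda a, n: a + (n - a) / (i + 1), acc, new)

def finalize(acc):
  # finalize: nothing left to do for a running mean.
  return acc

def aggregate(tree):
  # aggregate: collapse a leading per-microbatch axis into one value.
  return jax.tree.map(lambda x: jnp.mean(x, axis=0), tree)

# Tiny demonstration of the update rule on a single-leaf tree.
acc = init({'g': jnp.zeros(2)})
for i, g in enumerate([jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])]):
  acc = update(acc, {'g': g}, i)
print(finalize(acc))  # {'g': Array([2., 3.], dtype=float32)}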
