
Commit c34bb3c

Improve equality comparison for executors (#840)
Fix type hint for intermediate store on spec
Make spec __repr__ consistent with __eq__
Implement __eq__ for all executors
Unit test for executor equality
Implement __repr__ for all executors

1 parent b18af43 commit c34bb3c

File tree

11 files changed: +39 -15 lines changed

cubed/runtime/executors/beam.py

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ class BeamExecutor(DagExecutor):
     """An execution engine that uses Apache Beam."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/coiled.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ class CoiledExecutor(DagExecutor):
     """An execution engine that uses Coiled Functions."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/dask.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ class DaskExecutor(DagExecutor):
     """An execution engine that uses Dask Distributed's async API."""
 
    def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/lithops.py

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ class LithopsExecutor(DagExecutor):
     """An execution engine that uses Lithops."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/local.py

Lines changed: 2 additions & 2 deletions

@@ -114,7 +114,7 @@ class ThreadsExecutor(DagExecutor):
     """An execution engine that uses Python asyncio."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     # Tell NumPy to use a single thread
     # from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy
@@ -204,7 +204,7 @@ class ProcessesExecutor(DagExecutor):
     """An execution engine that uses local processes."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     # Tell NumPy to use a single thread
     # from https://stackoverflow.com/questions/30791550/limit-number-of-threads-in-numpy

cubed/runtime/executors/modal.py

Lines changed: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ class ModalExecutor(DagExecutor):
     """An execution engine that uses Modal's async API."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/ray.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ class RayExecutor(DagExecutor):
     """An execution engine that uses Ray."""
 
     def __init__(self, **kwargs):
-        self.kwargs = kwargs
+        super().__init__(**kwargs)
 
     @property
     def name(self) -> str:

cubed/runtime/executors/spark.py

Lines changed: 0 additions & 4 deletions

@@ -20,7 +20,6 @@ class SparkExecutor(DagExecutor):
     MIN_MEMORY_MiB = 512
 
     def __init__(self, **kwargs):
-        self._callbacks = None
         super().__init__(**kwargs)
 
     @property
@@ -57,9 +56,6 @@ def execute_dag(
         compute_id: Optional[str] = None,
         **kwargs: Any,
     ):
-        # Store callbacks for later use during computation
-        self._callbacks = callbacks
-
         # Configure Spark memory settings from Spec if provided
         spark_builder = SparkSession.builder
         if spec is not None and hasattr(spec, "allowed_mem") and spec.allowed_mem:

cubed/runtime/types.py

Lines changed: 12 additions & 0 deletions

@@ -8,6 +8,18 @@
 
 
 class DagExecutor:
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+
+    def __eq__(self, other):
+        if isinstance(other, DagExecutor):
+            return self.name == other.name and self.kwargs == other.kwargs
+        else:
+            return False
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(kwargs={self.kwargs})"
+
     @property
     def name(self) -> str:
         raise NotImplementedError  # pragma: no cover
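
With this base-class change, executor equality is determined by the executor's name plus the keyword arguments it was constructed with, and __repr__ reports both. A minimal sketch of the behaviour (the MyExecutor subclass and the retries kwarg are hypothetical, for illustration only):

from cubed.runtime.types import DagExecutor

class MyExecutor(DagExecutor):
    # Subclasses only need to provide a name; __init__, __eq__ and __repr__
    # now come from the DagExecutor base class.
    @property
    def name(self) -> str:
        return "my-executor"

# Executors configured identically compare equal; differing kwargs do not.
assert MyExecutor(retries=2) == MyExecutor(retries=2)
assert MyExecutor(retries=2) != MyExecutor(retries=3)

# __repr__ reports the class name and the stored kwargs.
print(MyExecutor(retries=2))  # MyExecutor(kwargs={'retries': 2})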

cubed/spec.py

Lines changed: 3 additions & 3 deletions

@@ -125,14 +125,14 @@ def zarr_compressor(self) -> Union[dict, str, None]:
         return self._zarr_compressor
 
     @property
-    def intermediate_store(self) -> Union[dict, str, None]:
+    def intermediate_store(self) -> Union[T_Store, None]:
         """The Zarr store for intermediate data. Takes precedence over ``work_dir``."""
         return self._intermediate_store
 
     def __repr__(self) -> str:
         return (
-            f"cubed.Spec(work_dir={self._work_dir}, intermediate_store={self._intermediate_store}, allowed_mem={self._allowed_mem}, "
-            f"reserved_mem={self._reserved_mem}, executor={self._executor}, storage_options={self._storage_options}, zarr_compressor={self._zarr_compressor})"
+            f"cubed.Spec(work_dir={self.work_dir}, intermediate_store={self.intermediate_store}, allowed_mem={self.allowed_mem}, "
+            f"reserved_mem={self.reserved_mem}, executor={self.executor}, storage_options={self.storage_options}, zarr_compressor={self.zarr_compressor})"
         )
 
     def __eq__(self, other):
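
Taken together with the executor changes above, Spec comparison now behaves as expected when each Spec constructs its own executor instance. A hedged sketch (the allowed_mem value and ThreadsExecutor configuration are illustrative, and this assumes Spec.__eq__ compares the executor field, as the commit message implies):

import cubed
from cubed.runtime.executors.local import ThreadsExecutor

# Two separately constructed but identically configured executors now
# compare equal, so the Specs that hold them do too.
a = cubed.Spec(allowed_mem="2GB", executor=ThreadsExecutor())
b = cubed.Spec(allowed_mem="2GB", executor=ThreadsExecutor())
assert a == b

# The repr now reads from the public properties, matching what __eq__ compares.
print(a)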
