Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pathwaysutils/debug/watchdog.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _log_thread_stack(thread: threading.Thread):
_logger.debug(
"".join(
traceback.format_stack(
sys._current_frames().get( # pylint: disable=protected-access
sys._current_frames().get( # pylint: disable=protected-access # pyrefly: ignore[no-matching-overload]
thread.ident, []
)
)
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def active_slice_count(self) -> int:
@property
def inactive_slice_indices(self) -> set[int]:
"""The set of inactive slice indices."""
return self.all_slice_indices - self.active_slice_indices
return self.all_slice_indices - self.active_slice_indices # pyrefly: ignore[bad-return]

def scale_by_active_slices(self, x: int | float) -> int | float:
"""Scale x by the number of active slices."""
Expand Down Expand Up @@ -363,7 +363,7 @@ def attempt_execution(attempt: int) -> Any:

attempt += 1

return wrapper
return wrapper # pyrefly: ignore[bad-return]

return decorator

Expand Down
12 changes: 6 additions & 6 deletions pathwaysutils/experimental/gke/jobset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
pathways_dir=pathways_dir,
num_vms=num_vms,
chips_per_vm=chips_per_vm,
gke_accel_type=gke_accel_type,
gke_accel_type=gke_accel_type, # pyrefly: ignore[bad-argument-type]
topology=topology,
image_tag=image_tag,
max_slice_restarts=max_slice_restarts,
Expand Down Expand Up @@ -306,8 +306,8 @@ def _build_head_job_template(
head_pod_spec.host_network = True
head_pod_spec.dns_policy = "ClusterFirstWithHostNet"

rm_container.restart_policy = "Always"
proxy_container.restart_policy = "Always"
rm_container.restart_policy = "Always" # pyrefly: ignore[missing-attribute]
proxy_container.restart_policy = "Always" # pyrefly: ignore[missing-attribute]

init_containers = head_pod_spec.init_containers or []
init_containers.extend([rm_container, proxy_container])
Expand Down Expand Up @@ -566,11 +566,11 @@ def _compile_config(self) -> dict[str, Any]:
},
}
if self._labels:
jobset_config["metadata"]["labels"] = self._labels
jobset_config["metadata"]["labels"] = self._labels # pyrefly: ignore[bad-assignment]
if self._annotations:
jobset_config["metadata"]["annotations"] = self._annotations
jobset_config["metadata"]["annotations"] = self._annotations # pyrefly: ignore[bad-assignment]
if self._success_policy:
jobset_config["spec"]["successPolicy"] = self._success_policy
jobset_config["spec"]["successPolicy"] = self._success_policy # pyrefly: ignore[bad-assignment]

return jobset_config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _create_metric_descriptor(
try:
self.client.create_metric_descriptor(
name=f"projects/{self.project_id}",
metric_descriptor={
metric_descriptor={ # pyrefly: ignore[bad-argument-type]
"type": metric_type,
"metric_kind": metric_kind,
"value_type": value_type,
Expand Down
2 changes: 1 addition & 1 deletion pathwaysutils/experimental/split_by_mesh_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def split_by_mesh_axis(
"Mesh axis sections must be monotonically increasing, but got"
f" {mesh_axis_sections=}."
)
mesh_axis_sections += [mesh.axis_sizes[mesh_axis_idx]]
mesh_axis_sections += [mesh.axis_sizes[mesh_axis_idx]] # pyrefly: ignore[unsupported-operation]

submeshes = []
axis_boundary_start = 0
Expand Down
6 changes: 3 additions & 3 deletions pathwaysutils/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def wrap(f: _F) -> _F:
cached = functools.lru_cache(maxsize=maxsize)(f)
wrapper = functools.wraps(f)(cached)

wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
wrapper.cache_clear = cached.cache_clear # pyrefly: ignore[missing-attribute]
wrapper.cache_info = cached.cache_info # pyrefly: ignore[missing-attribute]
backend.register_backend_cache(wrapper, "Pathways LRU cache")
return wrapper
return wrapper # pyrefly: ignore[bad-return]

return wrap
12 changes: 6 additions & 6 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def get_bulk_write_request(
) -> str:
"""Returns a string representation of a bulk write request, writes multiple arrays with one call."""
write_requests = [
get_write_request(location_path, name, jax_array, timeout, True)[
get_write_request(location_path, name, jax_array, timeout, True)[ # pyrefly: ignore[bad-index]
"persistenceWriteRequest"
]
for name, jax_array in zip(names, jax_arrays)
Expand All @@ -175,7 +175,7 @@ def get_read_request(
) -> str | Mapping[str, Any]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
devices = np.array(devices) # pyrefly: ignore[bad-assignment]

timeout_seconds, timeout_fractional_seconds = divmod(
timeout.total_seconds(), 1
Expand All @@ -190,7 +190,7 @@ def get_read_request(
sharding, len(shape)
),
"devices": {
"device_ids": [device.id for device in devices.flatten()]
"device_ids": [device.id for device in devices.flatten()] # pyrefly: ignore[missing-attribute]
},
"timeout": {
"seconds": int(timeout_seconds),
Expand All @@ -215,7 +215,7 @@ def get_bulk_read_request(
) -> str:
"""Returns a string representation of a bulk read request, reads multiple arrays with one call."""
read_requests = [
get_read_request(
get_read_request( # pyrefly: ignore[bad-index]
location_path, name, dtype, shape, sharding, devices, timeout, True
)["persistenceReadRequest"]
for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
Expand All @@ -233,7 +233,7 @@ def write_one_array(
):
"""Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
write_request = get_write_request(location, name, value, timeout)
write_executable = plugin_executable.PluginExecutable(write_request)
write_executable = plugin_executable.PluginExecutable(write_request) # pyrefly: ignore[bad-argument-type]
_, write_future = write_executable.call([value])
return write_future

Expand Down Expand Up @@ -263,7 +263,7 @@ def read_arrays(
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""

bulk_read_request = get_bulk_read_request(
location, names, dtypes, shapes, shardings, devices, timeout
location, names, dtypes, shapes, shardings, devices, timeout # pyrefly: ignore[bad-argument-type]
)
bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request)
out_avals = [
Expand Down
10 changes: 5 additions & 5 deletions pathwaysutils/persistence/orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ async def serialize(
"""Uses Pathways Persistence API to serialize a jax array."""
type_handlers.check_input_arguments(values, infos, args)

if any([arg.dtype is not None for arg in args]):
if any([arg.dtype is not None for arg in args]): # pyrefly: ignore[not-iterable]
raise ValueError("Casting during save not supported for Pathways.")

array_metadatas = []
any_random_key = False
arrays = []
for v, info, arg in zip(values, infos, args):
for v, info, arg in zip(values, infos, args): # pyrefly: ignore[bad-argument-type]
ext_metadata = None
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
any_random_key = True
Expand All @@ -118,7 +118,7 @@ async def serialize(
ArrayMetadata(
param_name=info.name,
shape=v.shape,
dtype=(arg.dtype if arg is not None else v.dtype),
dtype=(arg.dtype if arg is not None else v.dtype), # pyrefly: ignore[bad-argument-type]
write_shape=getattr(v, "local_shape", v.shape),
chunk_shape=getattr(v, "local_shape", v.shape),
use_ocdbt=False,
Expand Down Expand Up @@ -187,7 +187,7 @@ async def deserialize(
global_meshes.append(arg.mesh)
mesh_axes.append(arg.mesh_axes)
shardings.append(
jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes)
jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) # pyrefly: ignore[bad-argument-type]
)
else:
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
Expand Down Expand Up @@ -246,7 +246,7 @@ async def deserialize(
grouped_arrays, read_future = helper.read_arrays(
locations[0],
names,
grouped_dtypes,
grouped_dtypes, # pyrefly: ignore[bad-argument-type]
grouped_global_shapes,
grouped_shardings,
global_mesh.devices,
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/plugin_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class PluginExecutable:

def __init__(self, prog_str: str):
ifrt_client = jax.local_devices()[0].client
program = ifrt_programs.make_plugin_program(prog_str)
options = ifrt_programs.make_plugin_compile_options()
program = ifrt_programs.make_plugin_program(prog_str) # pyrefly: ignore[missing-attribute]
options = ifrt_programs.make_plugin_compile_options() # pyrefly: ignore[missing-attribute]
self.compiled = ifrt_client.compile_ifrt_program(program, options)

def call(
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ def start_server_patch(port: int, requires_backend: bool = True) -> None:
)
start_server(port, requires_backend=requires_backend)

jax.profiler.start_server = start_server_patch
jax._src.profiler.start_server = start_server_patch # pylint: disable=protected-access
jax.profiler.start_server = start_server_patch # pyrefly: ignore[bad-assignment]
jax._src.profiler.start_server = start_server_patch # pylint: disable=protected-access # pyrefly: ignore[bad-assignment]

def stop_server_patch() -> None:
_logger.debug("jax.profile.stop_server patched with pathways' stop_server")
Expand Down
2 changes: 1 addition & 1 deletion pathwaysutils/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def find_intermediate_sharding(

intermediate_mesh, intermediate_spec, replicated_axes = (
_build_intermediate_mesh_and_spec(
in_sharding.mesh,
in_sharding.mesh, # pyrefly: ignore[bad-argument-type]
in_sharding.spec,
src_dims,
dst_dims,
Expand Down
2 changes: 1 addition & 1 deletion pathwaysutils/test/experimental/gke/jobset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_monkeypatch_restart_policy(self):
# Construct V1Container with restart_policy to test monkeypatch.
c = client.V1Container(
name="test",
restart_policy="Always"
restart_policy="Always" # pyrefly: ignore[unexpected-keyword]
) # pytype: disable=wrong-keyword-args
self.assertEqual(getattr(c, "restart_policy"), "Always")

Expand Down
Loading