diff --git a/pathwaysutils/debug/watchdog.py b/pathwaysutils/debug/watchdog.py index 080c158..07b4af9 100644 --- a/pathwaysutils/debug/watchdog.py +++ b/pathwaysutils/debug/watchdog.py @@ -34,7 +34,7 @@ def _log_thread_stack(thread: threading.Thread): _logger.debug( "".join( traceback.format_stack( - sys._current_frames().get( # pylint: disable=protected-access + sys._current_frames().get( # pylint: disable=protected-access # pyrefly: ignore[no-matching-overload] thread.ident, [] ) ) diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index ccb8927..0345bf6 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -138,7 +138,7 @@ def active_slice_count(self) -> int: @property def inactive_slice_indices(self) -> set[int]: """The set of inactive slice indices.""" - return self.all_slice_indices - self.active_slice_indices + return self.all_slice_indices - self.active_slice_indices # pyrefly: ignore[bad-return] def scale_by_active_slices(self, x: int | float) -> int | float: """Scale x by the number of active slices.""" @@ -363,7 +363,7 @@ def attempt_execution(attempt: int) -> Any: attempt += 1 - return wrapper + return wrapper # pyrefly: ignore[bad-return] return decorator diff --git a/pathwaysutils/experimental/gke/jobset.py b/pathwaysutils/experimental/gke/jobset.py index 6cedb82..b353bac 100644 --- a/pathwaysutils/experimental/gke/jobset.py +++ b/pathwaysutils/experimental/gke/jobset.py @@ -164,7 +164,7 @@ def __init__( pathways_dir=pathways_dir, num_vms=num_vms, chips_per_vm=chips_per_vm, - gke_accel_type=gke_accel_type, + gke_accel_type=gke_accel_type, # pyrefly: ignore[bad-argument-type] topology=topology, image_tag=image_tag, max_slice_restarts=max_slice_restarts, @@ -306,8 +306,8 @@ def _build_head_job_template( head_pod_spec.host_network = True head_pod_spec.dns_policy = "ClusterFirstWithHostNet" - rm_container.restart_policy = "Always" - proxy_container.restart_policy = "Always" + rm_container.restart_policy = "Always" # pyrefly: ignore[missing-attribute] + proxy_container.restart_policy = "Always" # pyrefly: ignore[missing-attribute] init_containers = head_pod_spec.init_containers or [] init_containers.extend([rm_container, proxy_container]) @@ -566,11 +566,11 @@ def _compile_config(self) -> dict[str, Any]: }, } if self._labels: - jobset_config["metadata"]["labels"] = self._labels + jobset_config["metadata"]["labels"] = self._labels # pyrefly: ignore[bad-assignment] if self._annotations: - jobset_config["metadata"]["annotations"] = self._annotations + jobset_config["metadata"]["annotations"] = self._annotations # pyrefly: ignore[bad-assignment] if self._success_policy: - jobset_config["spec"]["successPolicy"] = self._success_policy + jobset_config["spec"]["successPolicy"] = self._success_policy # pyrefly: ignore[bad-assignment] return jobset_config diff --git a/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py index b37caec..01939f9 100644 --- a/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py +++ b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py @@ -247,7 +247,7 @@ def _create_metric_descriptor( try: self.client.create_metric_descriptor( name=f"projects/{self.project_id}", - metric_descriptor={ + metric_descriptor={ # pyrefly: ignore[bad-argument-type] "type": metric_type, "metric_kind": metric_kind, "value_type": value_type, diff --git a/pathwaysutils/experimental/split_by_mesh_axis.py b/pathwaysutils/experimental/split_by_mesh_axis.py index 88e0e97..14b7b8a 100644 --- a/pathwaysutils/experimental/split_by_mesh_axis.py +++ b/pathwaysutils/experimental/split_by_mesh_axis.py @@ -142,7 +142,7 @@ def split_by_mesh_axis( "Mesh axis sections must be monotonically increasing, but got" f" {mesh_axis_sections=}." ) - mesh_axis_sections += [mesh.axis_sizes[mesh_axis_idx]] + mesh_axis_sections += [mesh.axis_sizes[mesh_axis_idx]] # pyrefly: ignore[unsupported-operation] submeshes = [] axis_boundary_start = 0 diff --git a/pathwaysutils/lru_cache.py b/pathwaysutils/lru_cache.py index 13cfb8a..66dc4e5 100644 --- a/pathwaysutils/lru_cache.py +++ b/pathwaysutils/lru_cache.py @@ -40,9 +40,9 @@ def wrap(f: _F) -> _F: cached = functools.lru_cache(maxsize=maxsize)(f) wrapper = functools.wraps(f)(cached) - wrapper.cache_clear = cached.cache_clear - wrapper.cache_info = cached.cache_info + wrapper.cache_clear = cached.cache_clear # pyrefly: ignore[missing-attribute] + wrapper.cache_info = cached.cache_info # pyrefly: ignore[missing-attribute] backend.register_backend_cache(wrapper, "Pathways LRU cache") - return wrapper + return wrapper # pyrefly: ignore[bad-return] return wrap diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index ffbe1e0..915e669 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -153,7 +153,7 @@ def get_bulk_write_request( ) -> str: """Returns a string representation of a bulk write request, writes multiple arrays with one call.""" write_requests = [ - get_write_request(location_path, name, jax_array, timeout, True)[ + get_write_request(location_path, name, jax_array, timeout, True)[ # pyrefly: ignore[bad-index] "persistenceWriteRequest" ] for name, jax_array in zip(names, jax_arrays) @@ -175,7 +175,7 @@ def get_read_request( ) -> str | Mapping[str, Any]: """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" if not isinstance(devices, np.ndarray): - devices = np.array(devices) + devices = np.array(devices) # pyrefly: ignore[bad-assignment] timeout_seconds, timeout_fractional_seconds = divmod( timeout.total_seconds(), 1 @@ -190,7 +190,7 @@ def get_read_request( sharding, len(shape) ), "devices": { - "device_ids": [device.id for device in devices.flatten()] + "device_ids": [device.id for device in devices.flatten()] # pyrefly: ignore[missing-attribute] }, "timeout": { "seconds": int(timeout_seconds), @@ -215,7 +215,7 @@ def get_bulk_read_request( ) -> str: """Returns a string representation of a bulk read request, reads multiple arrays with one call.""" read_requests = [ - get_read_request( + get_read_request( # pyrefly: ignore[bad-index] location_path, name, dtype, shape, sharding, devices, timeout, True )["persistenceReadRequest"] for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings) @@ -233,7 +233,7 @@ def write_one_array( ): """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future.""" write_request = get_write_request(location, name, value, timeout) - write_executable = plugin_executable.PluginExecutable(write_request) + write_executable = plugin_executable.PluginExecutable(write_request) # pyrefly: ignore[bad-argument-type] _, write_future = write_executable.call([value]) return write_future @@ -263,7 +263,7 @@ def read_arrays( """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result.""" bulk_read_request = get_bulk_read_request( - location, names, dtypes, shapes, shardings, devices, timeout + location, names, dtypes, shapes, shardings, devices, timeout # pyrefly: ignore[bad-argument-type] ) bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request) out_avals = [ diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index 7417d12..12a43e3 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -100,13 +100,13 @@ async def serialize( """Uses Pathways Persistence API to serialize a jax array.""" type_handlers.check_input_arguments(values, infos, args) - if any([arg.dtype is not None for arg in args]): + if any([arg.dtype is not None for arg in args]): # pyrefly: ignore[not-iterable] raise ValueError("Casting during save not supported for Pathways.") array_metadatas = [] any_random_key = False arrays = [] - for v, info, arg in zip(values, infos, args): + for v, info, arg in zip(values, infos, args): # pyrefly: ignore[bad-argument-type] ext_metadata = None if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key): any_random_key = True @@ -118,7 +118,7 @@ async def serialize( ArrayMetadata( param_name=info.name, shape=v.shape, - dtype=(arg.dtype if arg is not None else v.dtype), + dtype=(arg.dtype if arg is not None else v.dtype), # pyrefly: ignore[bad-argument-type] write_shape=getattr(v, "local_shape", v.shape), chunk_shape=getattr(v, "local_shape", v.shape), use_ocdbt=False, @@ -187,7 +187,7 @@ async def deserialize( global_meshes.append(arg.mesh) mesh_axes.append(arg.mesh_axes) shardings.append( - jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) + jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) # pyrefly: ignore[bad-argument-type] ) else: if not isinstance(arg.sharding, jax.sharding.NamedSharding): @@ -246,7 +246,7 @@ async def deserialize( grouped_arrays, read_future = helper.read_arrays( locations[0], names, - grouped_dtypes, + grouped_dtypes, # pyrefly: ignore[bad-argument-type] grouped_global_shapes, grouped_shardings, global_mesh.devices, diff --git a/pathwaysutils/plugin_executable.py b/pathwaysutils/plugin_executable.py index e1c8956..51a1b5f 100644 --- a/pathwaysutils/plugin_executable.py +++ b/pathwaysutils/plugin_executable.py @@ -29,8 +29,8 @@ class PluginExecutable: def __init__(self, prog_str: str): ifrt_client = jax.local_devices()[0].client - program = ifrt_programs.make_plugin_program(prog_str) - options = ifrt_programs.make_plugin_compile_options() + program = ifrt_programs.make_plugin_program(prog_str) # pyrefly: ignore[missing-attribute] + options = ifrt_programs.make_plugin_compile_options() # pyrefly: ignore[missing-attribute] self.compiled = ifrt_client.compile_ifrt_program(program, options) def call( diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index 0b5b87e..2a39b85 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -463,8 +463,8 @@ def start_server_patch(port: int, requires_backend: bool = True) -> None: ) start_server(port, requires_backend=requires_backend) - jax.profiler.start_server = start_server_patch - jax._src.profiler.start_server = start_server_patch # pylint: disable=protected-access + jax.profiler.start_server = start_server_patch # pyrefly: ignore[bad-assignment] + jax._src.profiler.start_server = start_server_patch # pylint: disable=protected-access # pyrefly: ignore[bad-assignment] def stop_server_patch() -> None: _logger.debug("jax.profile.stop_server patched with pathways' stop_server") diff --git a/pathwaysutils/reshard.py b/pathwaysutils/reshard.py index 6891ae9..64bf117 100644 --- a/pathwaysutils/reshard.py +++ b/pathwaysutils/reshard.py @@ -368,7 +368,7 @@ def find_intermediate_sharding( intermediate_mesh, intermediate_spec, replicated_axes = ( _build_intermediate_mesh_and_spec( - in_sharding.mesh, + in_sharding.mesh, # pyrefly: ignore[bad-argument-type] in_sharding.spec, src_dims, dst_dims, diff --git a/pathwaysutils/test/experimental/gke/jobset_test.py b/pathwaysutils/test/experimental/gke/jobset_test.py index 9057a90..671ad74 100644 --- a/pathwaysutils/test/experimental/gke/jobset_test.py +++ b/pathwaysutils/test/experimental/gke/jobset_test.py @@ -129,7 +129,7 @@ def test_monkeypatch_restart_policy(self): # Construct V1Container with restart_policy to test monkeypatch. c = client.V1Container( name="test", - restart_policy="Always" + restart_policy="Always" # pyrefly: ignore[unexpected-keyword] ) # pytype: disable=wrong-keyword-args self.assertEqual(getattr(c, "restart_policy"), "Always")