ditk.distributed.env

Distributed training utilities for PyTorch.

This module provides utility functions to handle distributed training scenarios in PyTorch. It offers convenient methods to check distributed status, get process information, and determine the main process. The functions gracefully handle both distributed and non-distributed environments.

Example::
>>> # Check if distributed training is active
>>> if is_distributed():
...     print(f"Running on rank {get_rank()} of {get_world_size()}")
>>> 
>>> # Execute code only on main process
>>> if is_main_process():
...     print("This runs only on the main process")

is_distributed

ditk.distributed.env.is_distributed() → bool

Check if distributed training is available and initialized.

This function verifies whether PyTorch distributed training is both available (compiled with distributed support) and properly initialized. It handles cases where PyTorch or its distributed module might not be installed.

Returns:

True if distributed training is available and initialized, False otherwise.

Return type:

bool

Example::
>>> if is_distributed():
...     print("Distributed training is active")
... else:
...     print("Running in single-process mode")
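Below is a minimal sketch of how a check like this can be built on `torch.distributed`. It is not necessarily ditk's actual source. It assumes only the public `torch.distributed.is_available()` and `torch.distributed.is_initialized()` calls, and it defers the import so that a missing PyTorch install yields `False` instead of an `ImportError`.

```python
def is_distributed() -> bool:
    """Sketch: True only if torch.distributed is importable, available, and initialized."""
    try:
        import torch.distributed as dist
    except ImportError:
        # PyTorch (or its distributed package) is not installed.
        return False
    # is_available(): PyTorch was built with distributed support.
    # is_initialized(): a default process group exists (init_process_group was called).
    # Checking is_available() first short-circuits before touching the process group.
    return dist.is_available() and dist.is_initialized()


# Usage: in a plain single-process run (no init_process_group call) this reports False.
print(is_distributed())
```

In real code you would import `is_distributed` from `ditk.distributed.env` rather than redefining it.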

get_rank

ditk.distributed.env.get_rank() → int

Get the global rank of the current process.

Returns the global rank of the current process, that is, its index among all processes in the distributed job (not the operating-system process ID). In non-distributed environments, this function returns 0, making it safe to use in both distributed and single-process scenarios.

Returns:

Global rank of the current process. Returns 0 if distributed training is not active.

Return type:

int

Example::
>>> rank = get_rank()
>>> print(f"Current process rank: {rank}")

get_world_size

ditk.distributed.env.get_world_size() → int

Get the total number of processes across all nodes.

Returns the total number of processes participating in distributed training. In non-distributed environments, this function returns 1, so code that divides work by the world size behaves the same in single-process runs.

Returns:

Total number of processes in the distributed job. Returns 1 if distributed training is not active.

Return type:

int

Example::
>>> world_size = get_world_size()
>>> print(f"Total number of processes: {world_size}")
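A common use of the rank and world size together is splitting work so each process handles a disjoint slice. The sketch below assumes a local fallback `get_world_size` (mirroring the documented behaviour, not ditk's source) and a hypothetical helper `shard_indices` that performs a strided split. For PyTorch datasets, `torch.utils.data.distributed.DistributedSampler` is the library's built-in option for this.

```python
from typing import List


def get_world_size() -> int:
    """Sketch: process-group size when one is active, else 1."""
    try:
        import torch.distributed as dist
    except ImportError:
        return 1
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def shard_indices(num_samples: int, rank: int, world_size: int) -> List[int]:
    """Hypothetical helper: rank r receives samples r, r + W, r + 2W, ..."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    return list(range(rank, num_samples, world_size))


# Usage: with world_size == 1 (single process), rank 0 gets every sample.
my_indices = shard_indices(10, rank=0, world_size=get_world_size())
print(my_indices)
```

The strided split keeps shard sizes within one sample of each other without needing the dataset to be divisible by the world size.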

is_main_process

ditk.distributed.env.is_main_process() → bool

Check if the current process is the main process (global rank 0).

This function is useful for executing code that should only run once across all processes, such as logging, saving checkpoints, or printing progress. In non-distributed environments, it always returns True.

Returns:

True if the current process is the main process (rank 0), or if distributed training is not available.

Return type:

bool

Example::
>>> if is_main_process():
...     print("Saving model checkpoint...")
...     # Save checkpoint logic here
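Guarding side effects such as checkpoint writes can also be packaged as a decorator. The sketch below is illustrative only: `main_process_only` and `save_checkpoint` are hypothetical names, `is_main_process` is a local stand-in for the documented behaviour rather than ditk's source, and JSON replaces `torch.save` so the example has no dependencies.

```python
import json
import os
import tempfile
from functools import wraps
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")


def is_main_process() -> bool:
    """Stand-in: True on rank 0, or whenever no process group is active."""
    try:
        import torch.distributed as dist
    except ImportError:
        return True
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def main_process_only(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
    """Hypothetical decorator: run fn on the main process, return None elsewhere."""
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[T]:
        if is_main_process():
            return fn(*args, **kwargs)
        return None  # non-main ranks skip the side effect
    return wrapper


@main_process_only
def save_checkpoint(state: dict, path: str) -> str:
    # JSON stands in for torch.save to keep the sketch dependency-free.
    with open(path, "w") as f:
        json.dump(state, f)
    return path


# Usage: in a single-process run the checkpoint is written exactly once.
out_dir = tempfile.mkdtemp()
written = save_checkpoint({"step": 100}, os.path.join(out_dir, "ckpt.json"))
print(written)
```

In a real multi-process job, other ranks that later read this file typically need to wait until rank 0 has finished writing it, for example with a barrier.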