Represents a Mesh configuration over a certain list of Mesh Dimensions.
tf.experimental.dtensor.Mesh(
dim_names: List[str],
global_device_ids: np.ndarray,
local_device_ids: List[int],
local_devices: List[tf.compat.v1.DeviceSpec],
mesh_name: str = '',
global_devices: Optional[List[tf_device.DeviceSpec]] = None
)
A mesh consists of named dimensions with sizes, which describe how a set of devices are arranged. Defining tensor layouts in terms of mesh dimensions allows us to efficiently determine the communication required when computing an operation with tensors of different layouts.
A mesh provides information not only about the placement of the tensors but also the topology of the underlying devices. For example, we can group 8 TPUs as a 1-D array for data parallelism or a 2x4 grid for (2-way) data parallelism and (4-way) model parallelism.
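The 2x4 arrangement can be made concrete with a minimal sketch. It lays out 8 device IDs in row-major order (an assumption, consistent with the unravel_index example further down) and lists the groups of devices that differ only along one named dimension. A collective over that dimension would run within those groups. The helper device_groups and the dimension names "batch" and "model" are hypothetical; they are not part of the DTensor API.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple


def device_groups(shape: Sequence[int], dim_names: Sequence[str], dim: str) -> List[List[int]]:
    """Groups device IDs that differ only along mesh dimension `dim`.

    Hypothetical helper, not a DTensor API. Assumes device IDs are
    assigned in row-major order over `shape`.
    """
    axis = list(dim_names).index(dim)
    groups: Dict[Tuple[int, ...], List[int]] = {}
    # product() yields coordinates in row-major order, so the enumeration
    # index of each coordinate is its device ID.
    for device_id, coord in enumerate(product(*(range(n) for n in shape))):
        key = coord[:axis] + coord[axis + 1:]  # every coordinate except `dim`
        groups.setdefault(key, []).append(device_id)
    return list(groups.values())


# 8 devices as a 1-D mesh: one group spanning all devices.
print(device_groups([8], ["batch"], "batch"))             # [[0, 1, 2, 3, 4, 5, 6, 7]]
# 8 devices as a 2x4 mesh (2-way "batch", 4-way "model").
print(device_groups([2, 4], ["batch", "model"], "model"))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(device_groups([2, 4], ["batch", "model"], "batch"))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Under these assumptions, a reduction over "batch" (for example, averaging gradients in data parallelism) would pair devices across the two rows. A reduction over "model" would stay within each row of four.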
Note: the utilities dtensor.create_mesh and dtensor.create_distributed_mesh provide a simpler API to create meshes for single- or multi-client use cases.
| Args | |
|---|---|
| dim_names | A list of strings indicating dimension names. |
| global_device_ids | An ndarray of global device IDs used to compose the DeviceSpecs describing the mesh. The shape of this array determines the size of each mesh dimension. Values in this array should increment sequentially from 0. This argument is the same for every DTensor client. |
| local_device_ids | A list of local device IDs equal to a subset of the values in global_device_ids. They indicate the position of local devices in the global mesh. Different DTensor clients must hold disjoint sets of local_device_ids, and the local_device_ids of all DTensor clients together must cover every element of global_device_ids. |
| local_devices | The list of devices hosted locally. The elements correspond 1:1 to those of local_device_ids. |
| mesh_name | The name of the mesh. Currently, this is rarely used, and is mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh. |
| global_devices | (Optional) The list of global devices. Set when multiple device meshes are in use. |
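The constraints on global_device_ids and local_device_ids can be summarized as a check. This is a minimal sketch, assuming nested Python lists as a stand-in for the ndarray and reading "distinct contents" as "no device ID listed by more than one client". The helpers flatten and check_mesh_args are hypothetical and are not what the Mesh constructor actually runs.

```python
from typing import Dict, List, Sequence, Set


def flatten(nested) -> List[int]:
    """Flattens a nested list of device IDs (stand-in for np.ravel)."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


def check_mesh_args(global_device_ids, local_ids_per_client: Dict[str, Sequence[int]]) -> None:
    """Raises ValueError if the documented ID constraints are violated (hypothetical)."""
    flat = flatten(global_device_ids)
    # Values in global_device_ids should increment sequentially from 0.
    if flat != list(range(len(flat))):
        raise ValueError("global_device_ids must be 0, 1, 2, ... in order")
    all_ids = set(flat)
    claimed: Set[int] = set()
    for client, local_ids in local_ids_per_client.items():
        for device_id in local_ids:
            # Each local ID must be one of the global IDs ...
            if device_id not in all_ids:
                raise ValueError(f"{client}: {device_id} is not a global device ID")
            # ... and no device may be listed twice across clients.
            if device_id in claimed:
                raise ValueError(f"{client}: device {device_id} is listed more than once")
            claimed.add(device_id)
    # Together, the clients must cover every global device ID.
    missing = sorted(all_ids - claimed)
    if missing:
        raise ValueError(f"no client lists devices {missing}")


ids = [[0, 1, 2, 3], [4, 5, 6, 7]]  # a 2x4 mesh
check_mesh_args(ids, {"client0": [0, 1, 2, 3], "client1": [4, 5, 6, 7]})  # passes
try:
    check_mesh_args(ids, {"client0": [0, 1, 2, 3]})
except ValueError as err:
    print(err)  # no client lists devices [4, 5, 6, 7]
```

The sketch does not check local_devices. Per the table, that list would additionally need one entry per element of local_device_ids.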
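| Attributes | |
|---|---|
| dim_names | The list of dimension names of the mesh. |
| name | The name of the mesh. |
| size | The total number of devices in the mesh. |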
as_proto
as_proto() -> layout_pb2.MeshProto
Returns the mesh as a MeshProto protocol buffer.
contains_dim
contains_dim(
dim_name: str
) -> bool
Returns True if a Mesh contains the given dimension name.
device_type
device_type() -> str
Returns the device_type of a Mesh.
dim_size
dim_size(
dim_name: str
) -> int
Returns the size of a dimension.
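The argument table says the shape of global_device_ids fixes the size of each mesh dimension. This sketch relates shape, dim_size, and the size attribute under that reading, using plain nested lists. The helpers shape_of and dim_size_of are hypothetical stand-ins, not the Mesh methods themselves.

```python
from math import prod
from typing import List, Sequence


def shape_of(nested) -> List[int]:
    """Returns the shape of a rectangular nested list (like ndarray.shape)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return shape


def dim_size_of(dim_names: Sequence[str], global_device_ids, dim_name: str) -> int:
    # A dimension's size is the matching axis length of global_device_ids.
    return shape_of(global_device_ids)[list(dim_names).index(dim_name)]


ids = [[0, 1, 2, 3], [4, 5, 6, 7]]
names = ["batch", "model"]
print(shape_of(ids))                     # [2, 4]
print(dim_size_of(names, ids, "model"))  # 4
print(prod(shape_of(ids)))               # 8 devices in total
```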
from_proto
@staticmethod
from_proto(
proto: layout_pb2.MeshProto
) -> 'Mesh'
Construct a mesh instance from input proto.
from_string
@staticmethod
from_string(
mesh_str: str
) -> 'Mesh'
Construct a mesh instance from the input string.
host_mesh
host_mesh()
Returns the 1-1 mapped host mesh.
is_remote
is_remote() -> bool
Returns True if a Mesh contains only remote devices.
local_device_ids
local_device_ids() -> List[int]
Returns a list of local device IDs.
local_device_locations
local_device_locations() -> List[Dict[str, int]]
Returns a list of local device locations.
A device location is a dictionary from dimension names to indices on those dimensions.
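Here is a sketch of what these locations would look like for one client of a multi-client 2x4 mesh. It assumes device IDs map to coordinates in row-major order, as in the unravel_index example. The function device_locations is a hypothetical illustration of the described output, not the Mesh method.

```python
from itertools import product
from typing import Dict, List, Sequence


def device_locations(dim_names: Sequence[str], shape: Sequence[int],
                     local_device_ids: Sequence[int]) -> List[Dict[str, int]]:
    """Maps each local device ID to {dim_name: index}, assuming row-major IDs."""
    coords = list(product(*(range(n) for n in shape)))  # coords[i] is device i's position
    return [dict(zip(dim_names, coords[i])) for i in local_device_ids]


# A second client of a 2x4 mesh that hosts devices 4-7.
print(device_locations(["batch", "model"], [2, 4], [4, 5, 6, 7]))
# [{'batch': 1, 'model': 0}, {'batch': 1, 'model': 1}, {'batch': 1, 'model': 2}, {'batch': 1, 'model': 3}]
```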
local_devices
local_devices() -> List[str]
Returns a list of local device specs represented as strings.
min_global_device_id
min_global_device_id() -> int
Returns the minimum global device ID.
num_local_devices
num_local_devices() -> int
Returns the number of local devices.
shape
shape() -> List[int]
Returns the shape of the mesh.
to_string
to_string() -> str
Returns the string representation of the Mesh.
unravel_index
unravel_index()
Returns a dictionary from device ID to {dim_name: dim_index}.
For example, for a 3x2 mesh with dimension names 'x' and 'y', it returns:
{ 0: {'x': 0, 'y': 0},
  1: {'x': 0, 'y': 1},
  2: {'x': 1, 'y': 0},
  3: {'x': 1, 'y': 1},
  4: {'x': 2, 'y': 0},
  5: {'x': 2, 'y': 1} }
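The mapping above is a row-major unravelling of each device ID over the mesh shape: the last dimension varies fastest. This pure-Python sketch reproduces the 3x2 example. It assumes only the dimension names and shape as inputs and is an illustration, not the TensorFlow implementation.

```python
from typing import Dict, Sequence


def unravel_index(dim_names: Sequence[str], shape: Sequence[int]) -> Dict[int, Dict[str, int]]:
    """Maps every device ID to {dim_name: index} in row-major order."""
    size = 1
    for n in shape:
        size *= n
    result = {}
    for device_id in range(size):
        location = {}
        rest = device_id
        # Peel off indices starting from the last (fastest-varying) dimension.
        for name, n in zip(reversed(dim_names), reversed(shape)):
            rest, location[name] = divmod(rest, n)
        # Reorder keys to follow dim_names.
        result[device_id] = {name: location[name] for name in dim_names}
    return result


print(unravel_index(["x", "y"], [3, 2]))
```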
__contains__
__contains__(
dim_name: str
) -> bool
Returns True if the mesh contains the given dimension name.
__eq__
__eq__(
other
)
Return self==value.
__getitem__
__getitem__(
dim_name: str
) -> MeshDimension
Returns the MeshDimension with the given name.
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/experimental/dtensor/Mesh