Source code for gfloat.formats

# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from .block import BlockFormatInfo
from .types import FormatInfo

#: FormatInfo describing the IEEE-754 binary64 format
#: (64 bits wide, precision 53, emax 1023).
format_info_binary64 = FormatInfo(
    # Identity and overall width in bits
    name="binary64",
    k=64,
    # Precision and emax
    precision=53,
    emax=1023,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=True,
    num_high_nans=(1 << 52) - 1,
)

#: FormatInfo describing the IEEE-754 binary32 format
#: (32 bits wide, precision 24, emax 127).
format_info_binary32 = FormatInfo(
    # Identity and overall width in bits
    name="binary32",
    k=32,
    # Precision and emax
    precision=24,
    emax=127,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=True,
    num_high_nans=(1 << 23) - 1,
)

#: FormatInfo describing the IEEE-754 binary16 format
#: (16 bits wide, precision 11, emax 15).
format_info_binary16 = FormatInfo(
    # Identity and overall width in bits
    name="binary16",
    k=16,
    # Precision and emax
    precision=11,
    emax=15,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=True,
    num_high_nans=(1 << 10) - 1,
)

#: FormatInfo describing the Google BFloat16 format
#: (16 bits wide, precision 8, emax 127 -- the same emax as binary32).
format_info_bfloat16 = FormatInfo(
    # Identity and overall width in bits
    name="bfloat16",
    k=16,
    # Precision and emax
    precision=8,
    emax=127,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=True,
    num_high_nans=(1 << 7) - 1,
)

#: FormatInfo describing the OCP E5M2 format
#: (8 bits wide, precision 3, emax 15, with infinities).
format_info_ocp_e5m2 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e5m2",
    k=8,
    # Precision and emax
    precision=3,
    emax=15,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=True,
    num_high_nans=(1 << 2) - 1,
)

#: FormatInfo describing the OCP E4M3 format
#: (8 bits wide, precision 4, emax 8, no infinities, num_high_nans=1).
format_info_ocp_e4m3 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e4m3",
    k=8,
    # Precision and emax
    precision=4,
    emax=8,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=False,
    num_high_nans=1,
)

#: FormatInfo describing the OCP MX E2M3 format
#: (6 bits wide, precision 4, emax 2, no infinities, num_high_nans=0).
format_info_ocp_e2m3 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e2m3",
    k=6,
    # Precision and emax
    precision=4,
    emax=2,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=False,
    num_high_nans=0,
)

#: FormatInfo describing the OCP MX E3M2 format
#: (6 bits wide, precision 3, emax 4, no infinities, num_high_nans=0).
format_info_ocp_e3m2 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e3m2",
    k=6,
    # Precision and emax
    precision=3,
    emax=4,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=False,
    num_high_nans=0,
)

#: FormatInfo describing the OCP MX E2M1 format
#: (4 bits wide, precision 2, emax 2, no infinities, num_high_nans=0).
format_info_ocp_e2m1 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e2m1",
    k=4,
    # Precision and emax
    precision=2,
    emax=2,
    # Encoding conventions
    is_signed=True,
    is_twos_complement=False,
    has_subnormals=True,
    # Special values
    has_nz=True,
    has_infs=False,
    num_high_nans=0,
)

#: FormatInfo describing the OCP MX E8M0 format
#: (8 bits wide, precision 1, emax 127; unsigned, no subnormals,
#: no infinities, num_high_nans=1).
format_info_ocp_e8m0 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_e8m0",
    k=8,
    # Precision and emax
    precision=1,
    emax=127,
    # Encoding conventions: this is the only unsigned format defined here
    is_signed=False,
    is_twos_complement=False,
    has_subnormals=False,
    # Special values
    has_nz=False,
    has_infs=False,
    num_high_nans=1,
)

#: FormatInfo describing the OCP MX INT8 format
#: (8 bits wide, precision 8, emax 0; signed two's complement,
#: no infinities, num_high_nans=0).
format_info_ocp_int8 = FormatInfo(
    # Identity and overall width in bits
    name="ocp_int8",
    k=8,
    # Precision and emax
    precision=8,
    emax=0,
    # Encoding conventions: the only two's-complement format defined here
    is_signed=True,
    is_twos_complement=True,
    has_subnormals=True,
    # Special values
    has_nz=False,
    has_infs=False,
    num_high_nans=0,
)


def format_info_p3109(precision: int) -> FormatInfo:
    """
    FormatInfo for P3109 P{p} formats

    Args:
      precision (int): Precision in bits, from 1 to 7 inclusive

    Returns:
      FormatInfo describing the 8-bit P3109 format with the given precision

    Raises:
      ValueError: If precision is not in 1..7
    """
    if precision < 1 or precision > 7:
        raise ValueError(f"P3109 format not defined for p={precision}")

    # emax shrinks as precision grows: 63 for p=1 down to 0 for p=7.
    emax = 2 ** (7 - precision) - 1

    # Keyword arguments throughout, matching the other FormatInfo
    # definitions in this module.
    return FormatInfo(
        name=f"p3109_p{precision}",
        k=8,
        precision=precision,
        emax=emax,
        has_nz=False,
        has_infs=True,
        num_high_nans=0,
        has_subnormals=True,
        is_signed=True,
        is_twos_complement=False,
    )
# Collections of formats

# Sub-byte formats (the 4-bit and 6-bit OCP MX element formats)
_tiny_formats = [
    format_info_ocp_e2m1,
    format_info_ocp_e2m3,
    format_info_ocp_e3m2,
]

# NOTE(review): range(1, 7) yields p = 1..6 only, although
# format_info_p3109 also accepts p = 7 -- confirm whether P7 is
# deliberately left out of this collection.
p3109_formats = [format_info_p3109(p) for p in range(1, 7)]

# 8-bit floating-point formats
_fp8_formats = [
    format_info_ocp_e4m3,
    format_info_ocp_e5m2,
    *p3109_formats,
]

# 16-bit floating-point formats
_fp16_formats = [
    format_info_binary16,
    format_info_bfloat16,
]

# Every scalar format defined in this module
all_formats = [
    format_info_ocp_e8m0,
    format_info_ocp_int8,
    *_tiny_formats,
    *_fp8_formats,
    *_fp16_formats,
    format_info_binary32,
    format_info_binary64,
]

# ------
# Block formats
#
# Each block format pairs one element format with 32 and format_info_ocp_e8m0.
# NOTE(review): the positional arguments are presumably (name, element format,
# block size, scale format) -- confirm against BlockFormatInfo.

format_info_mxfp8_e5m2 = BlockFormatInfo(
    "mxfp8_e5m2", format_info_ocp_e5m2, 32, format_info_ocp_e8m0
)

format_info_mxfp8_e4m3 = BlockFormatInfo(
    "mxfp8_e4m3", format_info_ocp_e4m3, 32, format_info_ocp_e8m0
)

format_info_mxfp6_e3m2 = BlockFormatInfo(
    "mxfp6_e3m2", format_info_ocp_e3m2, 32, format_info_ocp_e8m0
)

format_info_mxfp6_e2m3 = BlockFormatInfo(
    "mxfp6_e2m3", format_info_ocp_e2m3, 32, format_info_ocp_e8m0
)

# Defined once: the original repeated this identical assignment twice.
format_info_mxfp4_e2m1 = BlockFormatInfo(
    "mxfp4_e2m1", format_info_ocp_e2m1, 32, format_info_ocp_e8m0
)

format_info_mxint8 = BlockFormatInfo(
    "mxint8", format_info_ocp_int8, 32, format_info_ocp_e8m0
)

# Every block format defined in this module
all_block_formats = [
    format_info_mxfp8_e5m2,
    format_info_mxfp8_e4m3,
    format_info_mxfp6_e3m2,
    format_info_mxfp6_e2m3,
    format_info_mxfp4_e2m1,
    format_info_mxint8,
]