from IPython.display import HTML, display
from jax.experimental.pallas import tpu as pltpu
# Column titles of the TPU specification table, in left-to-right display order.
headers = [
    "Version",
    "Generation",
    "TensorCores/Chip",
    "VMEM Capacity",
    "CMEM Capacity",
    "SMEM Capacity",
    "HBM Capacity",
    "HBM BW",
    "BF16 Peak",
    "FP8 Peak",
    "INT8 Peak",
    "INT4 Peak",
    "SparseCore",
]

# Accumulates the HTML document one output line per entry. This prologue holds
# the stylesheet, the opening <table> tag, the complete header row and the
# opening <tbody>; data rows and the closing tags are appended afterwards.
html_lines = [
    "<style>",
    # Forces the .bd-article element to 1300px so the wide table fits.
    " .bd-article {",
    " width: 1300px !important;",
    " }",
    # Base table appearance.
    " .tpu-spec-table {",
    " font-family: 'Google Sans', Arial, sans-serif;",
    " border-collapse: collapse;",
    " width: 100%;",
    " margin: 20px 0;",
    " font-size: 14px;",
    " }",
    # Header cells: white text on a dark background.
    " .tpu-spec-table th {",
    " background-color: #3c4043;",
    " color: white;",
    " text-align: left;",
    " padding: 12px 16px;",
    " font-weight: 500;",
    " border: 1px solid #dadce0;",
    " }",
    # Data cells.
    " .tpu-spec-table td {",
    " padding: 12px 16px;",
    " border: 1px solid #dadce0;",
    " color: #3c4043;",
    " }",
    # Zebra striping and hover highlight for rows.
    " .tpu-spec-table tr:nth-child(even) {",
    " background-color: #f8f9fa;",
    " }",
    " .tpu-spec-table tr:hover {",
    " background-color: #f1f3f4;",
    " }",
    "</style>",
    "<table class='tpu-spec-table'>",
    " <thead>",
    " <tr>",
    # One <th> cell per column title.
    *(f" <th>{title}</th>" for title in headers),
    " </tr>",
    " </thead>",
    " <tbody>",
]
def _peak_rate(ops_per_second, unit):
    """Format a peak-throughput value in whole tera-units.

    Args:
        ops_per_second: Raw operations-per-second figure.
        unit: Suffix appended after the number (e.g. "TFLOPs/s").

    Returns:
        The value divided by 1e12 and rounded to an integer, followed by
        ``unit``; "N/A" when the value is not greater than zero.
    """
    if ops_per_second > 0:
        return f"{int(round(ops_per_second / 1e12))} {unit}"
    return "N/A"


# Append one table row per chip version, then close the table and render it.
for chip in pltpu.ChipVersion:
    # TPU_7 is left out: its row would be redundant with the 7x entry.
    if chip == pltpu.ChipVersion.TPU_7:
        continue

    # The second argument (1) asks for per-TensorCore figures (num_cores=1).
    spec = pltpu.get_tpu_info_for_chip(chip, 1)

    # SparseCore column: "No" when absent, otherwise a short summary.
    sparse = spec.sparse_core
    if sparse is None:
        sparse_label = "No"
    else:
        sparse_label = (
            f"Yes ({sparse.num_cores} SCs, {sparse.num_subcores} subcores, "
            f"{sparse.vmem_capacity_bytes // 1024} KiB VMEM)"
        )

    # Generations below 7 are labelled with a "v" prefix; 7 and above are not.
    generation = spec.generation
    generation_prefix = "TPU v" if generation < 7 else "TPU "
    generation_label = f"{generation_prefix}{generation}"

    # CMEM and HBM bandwidth show "N/A" when the reported value is not positive.
    cmem_bytes = spec.cmem_capacity_bytes
    cmem_label = f"{cmem_bytes // 1024} KiB" if cmem_bytes > 0 else "N/A"

    hbm_bandwidth = spec.mem_bw_bytes_per_second
    bandwidth_label = (
        f"{hbm_bandwidth / 1e9:.1f} GB/s" if hbm_bandwidth > 0 else "N/A"
    )

    # Cell values, in the same order as the header row.
    cells = [
        chip.value.upper(),
        generation_label,
        str(chip.num_physical_tensor_cores_per_chip),
        f"{spec.vmem_capacity_bytes // (1024 * 1024)} MiB",
        cmem_label,
        f"{spec.smem_capacity_bytes // 1024} KiB",
        f"{spec.hbm_capacity_bytes // 1000000000} GB",
        bandwidth_label,
        _peak_rate(spec.bf16_ops_per_second, "TFLOPs/s"),
        _peak_rate(spec.fp8_ops_per_second, "TFLOPs/s"),
        _peak_rate(spec.int8_ops_per_second, "TOPs/s"),
        _peak_rate(spec.int4_ops_per_second, "TOPs/s"),
        sparse_label,
    ]

    html_lines.append(" <tr>")
    html_lines.extend(f" <td>{value}</td>" for value in cells)
    html_lines.append(" </tr>")

# Close the table body and the table, then render the joined markup.
html_lines.extend([" </tbody>", "</table>"])
table_markup = "\n".join(html_lines)
display(HTML(table_markup))