From 61749f7bfdb3a38d356f93f6775d6ff54c2fcc0f Mon Sep 17 00:00:00 2001 From: Catherine Date: Tue, 29 Apr 2025 02:53:59 +0000 Subject: [PATCH] jtag.tap: implement IEEE 1149.1 boundary scan TAP controller. Co-authored-by: bin --- amaranth_stdio/jtag/__init__.py | 0 amaranth_stdio/jtag/tap.py | 235 ++++++++++++++++++++++++++++++++ tests/test_tap.py | 77 +++++++++++ 3 files changed, 312 insertions(+) create mode 100644 amaranth_stdio/jtag/__init__.py create mode 100644 amaranth_stdio/jtag/tap.py create mode 100644 tests/test_tap.py diff --git a/amaranth_stdio/jtag/__init__.py b/amaranth_stdio/jtag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/amaranth_stdio/jtag/tap.py b/amaranth_stdio/jtag/tap.py new file mode 100644 index 0000000..9c5002d --- /dev/null +++ b/amaranth_stdio/jtag/tap.py @@ -0,0 +1,235 @@ +from typing import Iterable + +from amaranth import * +from amaranth.lib import enum, data, wiring, io, cdc +from amaranth.lib.wiring import In, Out + + +__all__ = ["State", "DataRegister", "Controller"] + + +class State(enum.Enum, shape=unsigned(4)): + Test_Logic_Reset = 0x0 + Run_Test_Idle = 0x8 + + Select_DR_Scan = 0x1 + Capture_DR = 0x2 + Shift_DR = 0x3 + Exit1_DR = 0x4 + Pause_DR = 0x5 + Exit2_DR = 0x6 + Update_DR = 0x7 + + Select_IR_Scan = 0x9 + Capture_IR = 0xA + Shift_IR = 0xB + Exit1_IR = 0xC + Pause_IR = 0xD + Exit2_IR = 0xE + Update_IR = 0xF + + +class DataRegister(wiring.PureInterface): + def __init__(self, length): + assert length >= 1, "DR must be at least 1 bit long" + + self._length = length + + super().__init__(wiring.Signature({ + "cap": In(length), + "upd": Out(length), + })) + + @property + def length(self): + return self._length + + +class Controller(wiring.Component): + def __init__(self, *, ir_length, ir_idcode=None): + assert ir_length >= 2, "IR must be at least 2 bits long" + + self._ir_length = ir_length + self._drs = dict() + + if ir_idcode is not None: + self._dr_idcode = self.add({ir_idcode}, length=32) + else: + self._dr_idcode = None + + super().__init__({ + # TRST# is implicit in the (asynchronous) reset signal of the `jtag` clock domain. + # TCK is implicit in the clock signal `jtag` clock domain. + "tms": Out(io.Buffer.Signature("i", 1)), + "tdi": Out(io.Buffer.Signature("i", 1)), + "tdo": Out(io.Buffer.Signature("o", 1)), + + # TAP state. + "state": Out(State, init=State.Test_Logic_Reset), + + # The high bits of the value loaded into the IR scan chain in the Capture-IR state. + # The low bits are fixed at `01` (with 1 loaded into the least significant bit). + "ir_cap": In(ir_length - 2), + # The last value loaded into the IR scan chain in the Update-IR state; in other words, + # the contents of the instruction register. + "ir_upd": Out(ir_length, init=~0 if ir_idcode is None else ir_idcode), + }) + + @property + def ir_length(self) -> int: + return self._ir_length + + @property + def dr_idcode(self) -> DataRegister: + return self._dr_idcode + + def add(self, ir_values: Iterable[int], *, length: int) -> DataRegister: + ir_values = set(ir_values) + + for ir_value in ir_values: + assert ir_value in range(0, 1 << self._ir_length), "IR value must be within range" + assert ir_value != ((1 << self._ir_length) - 1), "IR value must not be all-ones" + for used_ir_values in self._drs.values(): + assert not (ir_values & used_ir_values), "IR values must be unused" + + dr = DataRegister(length) + self._drs[dr] = ir_values + return dr + + def elaborate(self, platform): + m = Module() + + with m.Switch(self.state): + with m.Case(State.Test_Logic_Reset): + with m.If(~self.tms.i): + m.d.jtag += self.state.eq(State.Run_Test_Idle) + + with m.Case(State.Run_Test_Idle): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Select_DR_Scan) + + with m.Case(State.Select_DR_Scan): + with m.If(~self.tms.i): + m.d.jtag += self.state.eq(State.Capture_DR) + with m.Else(): + m.d.jtag += self.state.eq(State.Select_IR_Scan) + + with m.Case(State.Capture_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Exit1_DR) + with m.Else(): + m.d.jtag += self.state.eq(State.Shift_DR) + + with m.Case(State.Shift_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Exit1_DR) + + with m.Case(State.Exit1_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Update_DR) + with m.Else(): + m.d.jtag += self.state.eq(State.Pause_DR) + + with m.Case(State.Pause_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Exit2_DR) + + with m.Case(State.Exit2_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Update_DR) + with m.Else(): + m.d.jtag += self.state.eq(State.Shift_DR) + + with m.Case(State.Update_DR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Select_DR_Scan) + with m.Else(): + m.d.jtag += self.state.eq(State.Run_Test_Idle) + + with m.Case(State.Select_IR_Scan): + with m.If(~self.tms.i): + m.d.jtag += self.state.eq(State.Capture_IR) + with m.Else(): + m.d.jtag += self.state.eq(State.Test_Logic_Reset) + + with m.Case(State.Capture_IR): + with m.If(~self.tms.i): + m.d.jtag += self.state.eq(State.Shift_IR) + with m.Else(): + m.d.jtag += self.state.eq(State.Exit1_IR) + + with m.Case(State.Shift_IR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Exit1_IR) + + with m.Case(State.Exit1_IR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Update_IR) + with m.Else(): + m.d.jtag += self.state.eq(State.Pause_IR) + + with m.Case(State.Pause_IR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Exit2_IR) + + with m.Case(State.Exit2_IR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Update_IR) + with m.Else(): + m.d.jtag += self.state.eq(State.Shift_IR) + + with m.Case(State.Update_IR): + with m.If(self.tms.i): + m.d.jtag += self.state.eq(State.Select_DR_Scan) + with m.Else(): + m.d.jtag += self.state.eq(State.Run_Test_Idle) + + dr_chain = Signal(max([1, *(dr.length for dr in self._drs)])) + ir_chain = Signal(self.ir_length) + + with m.Switch(self.state): + m.d.comb += self.tdo.oe.eq(0) + + with m.Case(State.Test_Logic_Reset): + m.d.jtag += self.ir_upd.eq(self.ir_upd.init) + for dr, ir_values in self._drs.items(): + m.d.jtag += dr.upd.eq(dr.upd.init) + + with m.Case(State.Capture_DR): + with m.Switch(self.ir_upd): + for dr, ir_values in self._drs.items(): + with m.Case(*ir_values): + m.d.jtag += dr_chain[-dr.length:].eq(dr.cap) + with m.Default(): # BYPASS + m.d.jtag += dr_chain.eq(0) + + with m.Case(State.Shift_DR): + m.d.jtag += dr_chain.eq(Cat(dr_chain[1:], self.tdi.i)) + with m.Switch(self.ir_upd): + for dr, ir_values in self._drs.items(): + with m.Case(*ir_values): + m.d.comb += self.tdo.o.eq(dr_chain[-dr.length]) + with m.Default(): # BYPASS + m.d.comb += self.tdo.o.eq(dr_chain[-1]) + m.d.comb += self.tdo.oe.eq(1) + + with m.Case(State.Update_DR): + with m.Switch(self.ir_upd): + for dr, ir_values in self._drs.items(): + with m.Case(*ir_values): + m.d.jtag += dr.upd.eq(dr_chain[-dr.length:]) + with m.Default(): # BYPASS + pass + + with m.Case(State.Capture_IR): + m.d.jtag += ir_chain.eq(Cat(1, 0, self.ir_cap)) + + with m.Case(State.Shift_IR): + m.d.jtag += ir_chain.eq(Cat(ir_chain[1:], self.tdi.i)) + m.d.comb += self.tdo.o.eq(ir_chain[0]) + m.d.comb += self.tdo.oe.eq(1) + + with m.Case(State.Update_IR): + m.d.jtag += self.ir_upd.eq(ir_chain) + + return m diff --git a/tests/test_tap.py b/tests/test_tap.py new file mode 100644 index 0000000..bb8e2c4 --- /dev/null +++ b/tests/test_tap.py @@ -0,0 +1,77 @@ +import functools +import unittest + +from amaranth import * +from amaranth.sim import Simulator +from amaranth_stdio.jtag import tap + + +async def shift_tms(ctx, dut, tms, state_after, *, expected={}): + ctx.set(dut.tms.i, tms) + + # HACK(bin): i'm so sorry? + (_, _, *sampled) = await ctx.tick("jtag").sample(*[getattr(dut, s).o for s in expected.keys()]) + assert ctx.get(dut.state == state_after) + + for (dut_value, (name, expected_value)) in zip(sampled, expected.items()): + assert dut_value == expected_value, f"dut.{name} != {expected_value:#b}" + + +class TAPTestCase(unittest.TestCase): + def test_tap_controller(self): + ir_idcode = 0b10101010 + dr_idcode = 0b0011_1111000011110000_00001010100_1 + + m = Module() + m.submodules.dut = dut = tap.Controller(ir_length=8, ir_idcode=ir_idcode) + m.d.comb += dut.dr_idcode.cap.eq(dr_idcode) + + async def testbench(ctx): + global shift_tms + shift_tms = functools.partial(shift_tms, ctx, dut) + + assert ctx.get(dut.state) == tap.State.Test_Logic_Reset + + await shift_tms(0, tap.State.Run_Test_Idle) + await shift_tms(1, tap.State.Select_DR_Scan) + await shift_tms(0, tap.State.Capture_DR) + await shift_tms(0, tap.State.Shift_DR) + + for i in range(32): + await shift_tms(0, tap.State.Shift_DR, expected={ + "tdo": (dr_idcode >> i) & 1 + }) + + await shift_tms(1, tap.State.Exit1_DR) + await shift_tms(0, tap.State.Pause_DR) + await shift_tms(1, tap.State.Exit2_DR) + await shift_tms(1, tap.State.Update_DR) + await shift_tms(1, tap.State.Select_DR_Scan) + await shift_tms(1, tap.State.Select_IR_Scan) + + ctx.set(dut.ir_cap, 0b111111) + await shift_tms(0, tap.State.Capture_IR) + await shift_tms(0, tap.State.Shift_IR) + await shift_tms(0, tap.State.Shift_IR, expected={ + "tdo": 0b1 + }) + await shift_tms(1, tap.State.Exit1_IR, expected={ + "tdo": 0b0 + }) + await shift_tms(0, tap.State.Pause_IR) + await shift_tms(1, tap.State.Exit2_IR) + await shift_tms(1, tap.State.Update_IR) + await shift_tms(1, tap.State.Select_DR_Scan) + assert ctx.get(dut.ir_upd) == 0b111111 + + await shift_tms(1, tap.State.Select_IR_Scan) + await shift_tms(1, tap.State.Test_Logic_Reset) + + await shift_tms(1, tap.State.Test_Logic_Reset) + assert ctx.get(dut.ir_upd) == ir_idcode + + sim = Simulator(m) + sim.add_clock(1e-3, domain="jtag") + sim.add_testbench(testbench) + with sim.write_vcd("test_tap.vcd"): + sim.run()