From d218273b9b2c6e65b7d92eb0f280306ea9c07ea3 Mon Sep 17 00:00:00 2001 From: Charlotte Date: Thu, 22 Jun 2023 11:30:54 +1000 Subject: [PATCH] hdl.ast: deprecate `Repl` and remove from AST; add `Value.replicate`. --- amaranth/back/rtlil.py | 6 --- amaranth/compat/fhdl/structure.py | 4 +- amaranth/hdl/ast.py | 62 ++++++++++++++++--------------- amaranth/hdl/mem.py | 2 +- amaranth/hdl/xfrm.py | 12 ------ amaranth/sim/_pyrtl.py | 15 -------- amaranth/vendor/intel.py | 2 +- amaranth/vendor/lattice_ecp5.py | 2 +- docs/changes.rst | 4 ++ docs/lang.rst | 4 +- tests/test_hdl_ast.py | 39 ++++++++++--------- tests/test_hdl_xfrm.py | 17 ++++++++- tests/test_sim.py | 9 +---- 13 files changed, 83 insertions(+), 95 deletions(-) diff --git a/amaranth/back/rtlil.py b/amaranth/back/rtlil.py index c605f3c76..a9f026913 100644 --- a/amaranth/back/rtlil.py +++ b/amaranth/back/rtlil.py @@ -632,9 +632,6 @@ def on_Part(self, value): }, src=_src(value.src_loc)) return res - def on_Repl(self, value): - return "{{ {} }}".format(" ".join(self(value.value) for _ in range(value.count))) - class _LHSValueCompiler(_ValueCompiler): def on_Const(self, value): @@ -695,9 +692,6 @@ def on_Part(self, value): range(1 << len(value.offset))[:max_branches], value.src_loc) - def on_Repl(self, value): - raise TypeError # :nocov: - class _StatementCompiler(xfrm.StatementVisitor): def __init__(self, state, rhs_compiler, lhs_compiler): diff --git a/amaranth/compat/fhdl/structure.py b/amaranth/compat/fhdl/structure.py index c0d4434c9..9aa708736 100644 --- a/amaranth/compat/fhdl/structure.py +++ b/amaranth/compat/fhdl/structure.py @@ -69,9 +69,9 @@ def Constant(value, bits_sign=None): return Const(value, bits_sign) -@deprecated("instead of `Replicate`, use `Repl`") +@deprecated("instead of `Replicate(v, n)`, use `v.replicate(n)`") def Replicate(v, n): - return Repl(v, n) + return v.replicate(n) @extend(Const) diff --git a/amaranth/hdl/ast.py b/amaranth/hdl/ast.py index b431a6de8..0a728c68a 100644 --- a/amaranth/hdl/ast.py +++ b/amaranth/hdl/ast.py @@ -539,6 +539,29 @@ def rotate_right(self, amount): amount %= len(self) return Cat(self[amount:], self[:amount]) + def replicate(self, count): + """Replication. + + A ``Value`` is replicated (repeated) several times to be used + on the RHS of assignments:: + + len(v.replicate(n)) == len(v) * n + + Parameters + ---------- + count : int + Number of replications. + + Returns + ------- + Value, out + Replicated value. + """ + if not isinstance(count, int) or count < 0: + raise TypeError("Replication count must be a non-negative integer, not {!r}" + .format(count)) + return Cat(self for _ in range(count)) + def eq(self, value): """Assignment. @@ -914,8 +937,9 @@ def __repr__(self): return "(cat {})".format(" ".join(map(repr, self.parts))) -@final -class Repl(Value): +# TODO(amaranth-0.5): remove +@deprecated("instead of `Repl(value, count)`, use `value.replicate(count)`") +def Repl(value, count): """Replicate a value An input value is replicated (repeated) several times @@ -932,31 +956,16 @@ class Repl(Value): Returns ------- - Repl, out + Value, out Replicated value. """ - def __init__(self, value, count, *, src_loc_at=0): - if not isinstance(count, int) or count < 0: - raise TypeError("Replication count must be a non-negative integer, not {!r}" - .format(count)) + if isinstance(value, int) and value not in [0, 1]: + warnings.warn("Value argument of Repl() is a bare integer {} used in bit vector " + "context; consider specifying explicit width using C({}, {}) instead" + .format(value, value, bits_for(value)), + SyntaxWarning, stacklevel=3) - super().__init__(src_loc_at=src_loc_at) - if isinstance(value, int) and value not in [0, 1]: - warnings.warn("Value argument of Repl() is a bare integer {} used in bit vector " - "context; consider specifying explicit width using C({}, {}) instead" - .format(value, value, bits_for(value)), - SyntaxWarning, stacklevel=2 + src_loc_at) - self.value = Value.cast(value) - self.count = count - - def shape(self): - return Shape(len(self.value) * self.count) - - def _rhs_signals(self): - return self.value._rhs_signals() - - def __repr__(self): - return "(repl {!r} {})".format(self.value, self.count) + return Value.cast(value).replicate(count) class _SignalMeta(ABCMeta): @@ -1728,8 +1737,6 @@ def __init__(self, value): tuple(ValueKey(e) for e in self.value._iter_as_values()))) elif isinstance(self.value, Sample): self._hash = hash((ValueKey(self.value.value), self.value.clocks, self.value.domain)) - elif isinstance(self.value, Repl): - self._hash = hash((ValueKey(self.value.value), self.value.count)) elif isinstance(self.value, Initial): self._hash = 0 else: # :nocov: @@ -1769,9 +1776,6 @@ def __eq__(self, other): return (len(self.value.parts) == len(other.value.parts) and all(ValueKey(a) == ValueKey(b) for a, b in zip(self.value.parts, other.value.parts))) - elif isinstance(self.value, Repl): - return (ValueKey(self.value.value) == ValueKey(other.value.value) and - self.value.count == other.value.count) elif isinstance(self.value, ArrayProxy): return (ValueKey(self.value.index) == ValueKey(other.value.index) and len(self.value.elems) == len(other.value.elems) and diff --git a/amaranth/hdl/mem.py b/amaranth/hdl/mem.py index fd7c8745f..58bf6da84 100644 --- a/amaranth/hdl/mem.py +++ b/amaranth/hdl/mem.py @@ -281,7 +281,7 @@ def elaborate(self, platform): p_CLK_POLARITY=1, p_PRIORITY=0, i_CLK=ClockSignal(self.domain), - i_EN=Cat(Repl(en_bit, self.granularity) for en_bit in self.en), + i_EN=Cat(en_bit.replicate(self.granularity) for en_bit in self.en), i_ADDR=self.addr, i_DATA=self.data, ) diff --git a/amaranth/hdl/xfrm.py b/amaranth/hdl/xfrm.py index d49f30115..fdd25944e 100644 --- a/amaranth/hdl/xfrm.py +++ b/amaranth/hdl/xfrm.py @@ -62,10 +62,6 @@ def on_Part(self, value): def on_Cat(self, value): pass # :nocov: - @abstractmethod - def on_Repl(self, value): - pass # :nocov: - @abstractmethod def on_ArrayProxy(self, value): pass # :nocov: @@ -106,8 +102,6 @@ def on_value(self, value): new_value = self.on_Part(value) elif type(value) is Cat: new_value = self.on_Cat(value) - elif type(value) is Repl: - new_value = self.on_Repl(value) elif type(value) is ArrayProxy: new_value = self.on_ArrayProxy(value) elif type(value) is Sample: @@ -156,9 +150,6 @@ def on_Part(self, value): def on_Cat(self, value): return Cat(self.on_value(o) for o in value.parts) - def on_Repl(self, value): - return Repl(self.on_value(value.value), value.count) - def on_ArrayProxy(self, value): return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()], self.on_value(value.index)) @@ -374,9 +365,6 @@ def on_Cat(self, value): for o in value.parts: self.on_value(o) - def on_Repl(self, value): - self.on_value(value.value) - def on_ArrayProxy(self, value): for elem in value._iter_as_values(): self.on_value(elem) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index 0e3784771..fb182981c 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -213,18 +213,6 @@ def on_Cat(self, value): return f"({' | '.join(gen_parts)})" return f"0" - def on_Repl(self, value): - part_mask = (1 << len(value.value)) - 1 - gen_part = self.emitter.def_var("repl", f"{part_mask:#x} & {self(value.value)}") - gen_parts = [] - offset = 0 - for _ in range(value.count): - gen_parts.append(f"({gen_part} << {offset})") - offset += len(value.value) - if gen_parts: - return f"({' | '.join(gen_parts)})" - return f"0" - def on_ArrayProxy(self, value): index_mask = (1 << len(value.index)) - 1 gen_index = self.emitter.def_var("rhs_index", f"{index_mask:#x} & {self(value.index)}") @@ -325,9 +313,6 @@ def gen(arg): offset += len(part) return gen - def on_Repl(self, value): - raise TypeError # :nocov: - def on_ArrayProxy(self, value): def gen(arg): index_mask = (1 << len(value.index)) - 1 diff --git a/amaranth/vendor/intel.py b/amaranth/vendor/intel.py index 131026a6b..797fbe67b 100644 --- a/amaranth/vendor/intel.py +++ b/amaranth/vendor/intel.py @@ -375,7 +375,7 @@ def get_oneg(o): def _get_oereg(m, pin): # altiobuf_ requires an output enable signal for each pin, but pin.oe is 1 bit wide. if pin.xdr == 0: - return Repl(pin.oe, pin.width) + return pin.oe.replicate(pin.width) elif pin.xdr in (1, 2): oe_reg = Signal(pin.width, name="{}_oe_reg".format(pin.name)) oe_reg.attrs["useioff"] = "1" diff --git a/amaranth/vendor/lattice_ecp5.py b/amaranth/vendor/lattice_ecp5.py index b5ebc1a9c..2b5410776 100644 --- a/amaranth/vendor/lattice_ecp5.py +++ b/amaranth/vendor/lattice_ecp5.py @@ -528,7 +528,7 @@ def get_oneg(a, invert): if "o" in pin.dir: o = pin_o if pin.dir in ("oe", "io"): - t = Repl(~pin.oe, pin.width) + t = (~pin.oe).replicate(pin.width) elif pin.xdr == 1: if "i" in pin.dir: get_ireg(pin.i_clk, i, pin_i) diff --git a/docs/changes.rst b/docs/changes.rst index 5e55a3416..861840751 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -37,6 +37,7 @@ Implemented RFCs .. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html .. _RFC 8: https://amaranth-lang.org/rfcs/0008-aggregate-extensibility.html .. _RFC 9: https://amaranth-lang.org/rfcs/0009-const-init-shape-castable.html +.. _RFC 10: https://amaranth-lang.org/rfcs/0010-move-repl-to-value.html .. _RFC 15: https://amaranth-lang.org/rfcs/0015-lifting-shape-castables.html * `RFC 1`_: Aggregate data structure library @@ -45,6 +46,7 @@ Implemented RFCs * `RFC 5`_: Remove Const.normalize * `RFC 8`_: Aggregate extensibility * `RFC 9`_: Constant initialization for shape-castable objects +* `RFC 10`_: Move Repl to Value.replicate * `RFC 15`_: Lifting shape-castable objects @@ -57,12 +59,14 @@ Language changes * Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior). * Added: :meth:`Const.cast`. (`RFC 4`_) * Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_) +* Added: :meth:`Value.replicate`, superseding :class:`Repl`. (`RFC 10`_) * Changed: creating a :class:`Signal` with a shape that is a :class:`ShapeCastable` implementing :meth:`ShapeCastable.__call__` wraps the returned object using that method. (`RFC 15`_) * Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively. * Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers. * Changed: :meth:`Value.matches` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match the behavior of ``with m.Case():``. * Changed: :class:`Cat` warns if an enumeration without an explicitly specified shape is used. (`RFC 3`_) * Deprecated: :meth:`Const.normalize`. (`RFC 5`_) +* Deprecated: :class:`Repl`; use :meth:`Value.replicate` instead. (`RFC 10`_) * Removed: (deprecated in 0.1) casting of :class:`Shape` to and from a ``(width, signed)`` tuple. * Removed: (deprecated in 0.3) :class:`ast.UserValue`. * Removed: (deprecated in 0.3) support for ``# nmigen:`` linter instructions at the beginning of file. diff --git a/docs/lang.rst b/docs/lang.rst index 0ab5973a9..cc22baf12 100644 --- a/docs/lang.rst +++ b/docs/lang.rst @@ -705,7 +705,7 @@ Operation Description Notes ``a.bit_select(b, w)`` overlapping part select with variable offset ``a.word_select(b, w)`` non-overlapping part select with variable offset ``Cat(a, b)`` concatenation [#opS3]_ -``Repl(a, n)`` replication +``a.replicate(n)`` replication ======================= ================================================ ====== .. [#opS1] Words "length" and "width" have the same meaning when talking about Amaranth values. Conventionally, "width" is used. @@ -718,7 +718,7 @@ For the operators introduced by Amaranth, the following table explains them in t Amaranth operation Equivalent Python code ======================= ====================== ``Cat(a, b)`` ``a + b`` -``Repl(a, n)`` ``a * n`` +``a.replicate(n)`` ``a * n`` ``a.bit_select(b, w)`` ``a[b:b+w]`` ``a.word_select(b, w)`` ``a[b*w:b*w+w]`` ======================= ====================== diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 5878a6c51..317780710 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -353,6 +353,23 @@ def test_rotate_right_wrong(self): r"^Rotate amount must be an integer, not 'str'$"): Const(31).rotate_right("str") + def test_replicate_shape(self): + s1 = Const(10).replicate(3) + self.assertEqual(s1.shape(), unsigned(12)) + self.assertIsInstance(s1.shape(), Shape) + s2 = Const(10).replicate(0) + self.assertEqual(s2.shape(), unsigned(0)) + + def test_replicate_count_wrong(self): + with self.assertRaises(TypeError): + Const(10).replicate(-1) + with self.assertRaises(TypeError): + Const(10).replicate("str") + + def test_replicate_repr(self): + s = Const(10).replicate(3) + self.assertEqual(repr(s), "(cat (const 4'd10) (const 4'd10) (const 4'd10))") + class ConstTestCase(FHDLTestCase): def test_shape(self): @@ -863,33 +880,19 @@ def test_int_wrong(self): class ReplTestCase(FHDLTestCase): - def test_shape(self): - s1 = Repl(Const(10), 3) - self.assertEqual(s1.shape(), unsigned(12)) - self.assertIsInstance(s1.shape(), Shape) - s2 = Repl(Const(10), 0) - self.assertEqual(s2.shape(), unsigned(0)) - - def test_count_wrong(self): - with self.assertRaises(TypeError): - Repl(Const(10), -1) - with self.assertRaises(TypeError): - Repl(Const(10), "str") - - def test_repr(self): - s = Repl(Const(10), 3) - self.assertEqual(repr(s), "(repl (const 4'd10) 3)") - + @_ignore_deprecated def test_cast(self): r = Repl(0, 3) - self.assertEqual(repr(r), "(repl (const 1'd0) 3)") + self.assertEqual(repr(r), "(cat (const 1'd0) (const 1'd0) (const 1'd0))") + @_ignore_deprecated def test_int_01(self): with warnings.catch_warnings(): warnings.filterwarnings(action="error", category=SyntaxWarning) Repl(0, 3) Repl(1, 3) + @_ignore_deprecated def test_int_wrong(self): with self.assertWarnsRegex(SyntaxWarning, r"^Value argument of Repl\(\) is a bare integer 2 used in bit vector context; " diff --git a/tests/test_hdl_xfrm.py b/tests/test_hdl_xfrm.py index 5a55eb84a..a8648192b 100644 --- a/tests/test_hdl_xfrm.py +++ b/tests/test_hdl_xfrm.py @@ -556,7 +556,22 @@ def test_enable_write_port(self): mem = Memory(width=8, depth=4) f = EnableInserter(self.c1)(mem.write_port()).elaborate(platform=None) self.assertRepr(f.named_ports["EN"][0], """ - (m (sig c1) (cat (repl (slice (sig mem_w_en) 0:1) 8)) (const 8'd0)) + (m + (sig c1) + (cat + (cat + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + (slice (sig mem_w_en) 0:1) + ) + ) + (const 8'd0) + ) """) diff --git a/tests/test_sim.py b/tests/test_sim.py index 80f42958c..99fdf9623 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -289,8 +289,8 @@ def test_record(self): stmt = lambda y, a: [rec.eq(a), y.eq(rec)] self.assertStatement(stmt, [C(0b101, 3)], C(0b101, 3)) - def test_repl(self): - stmt = lambda y, a: y.eq(Repl(a, 3)) + def test_replicate(self): + stmt = lambda y, a: y.eq(a.replicate(3)) self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6)) def test_array(self): @@ -879,11 +879,6 @@ def test_bug_325(self): dut.d.comb += Signal().eq(Cat()) Simulator(dut).run() - def test_bug_325_bis(self): - dut = Module() - dut.d.comb += Signal().eq(Repl(Const(1), 0)) - Simulator(dut).run() - def test_bug_473(self): sim = Simulator(Module()) def process():