conv


class braintrace.conv(x, kernel, bias=None, *, strides=(1,), padding='SAME', lhs_dilation=None, rhs_dilation=None, feature_group_count=1, batch_group_count=1, dimension_numbers=None)

ETP-aware convolution.

Computes \(y = \mathrm{conv}(x, \mathrm{kernel}) + b\), where the bias term \(b\) is added only when bias is given. The kernel and the optional bias are routed through an ETP primitive so that they participate in eligibility-trace computation. The convolution keywords mirror those of jax.lax.conv_general_dilated() (strides corresponds to its window_strides argument). x must always carry a leading batch dimension.
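For intuition, the forward value in the default 1-D case can be written out by hand. The sketch below is a plain NumPy reference, assuming the JAX default NCH input / OIH kernel layout, feature_group_count=1 and no lhs_dilation; it ignores eligibility traces entirely and is not how braintrace implements conv. The function name reference_conv1d is hypothetical, and broadcasting the bias along the channel axis is an assumption for this layout. Like JAX convolutions, it computes a cross-correlation (the kernel is not flipped) and uses JAX's 'SAME' padding rule.

```python
import math

import numpy as np


def reference_conv1d(x, kernel, bias=None, stride=1, padding="SAME", rhs_dilation=1):
    """Hypothetical NumPy reference for a 1-D convolution in NCH / OIH layout.

    Intended to match the forward value of jax.lax.conv_general_dilated for
    this layout with feature_group_count=1 and no lhs_dilation. Eligibility
    traces are not modelled.
    """
    x = np.asarray(x, dtype=float)
    kernel = np.asarray(kernel, dtype=float)
    n, c_in, length = x.shape
    c_out, k_in, k = kernel.shape
    if k_in != c_in:
        raise ValueError("this sketch assumes feature_group_count=1")

    eff_k = (k - 1) * rhs_dilation + 1  # kernel extent after rhs dilation
    if padding == "SAME":
        # JAX rule: output length ceil(length / stride), padding split low/high.
        target = math.ceil(length / stride)
        total = max((target - 1) * stride + eff_k - length, 0)
        pad_lo, pad_hi = total // 2, total - total // 2
    elif padding == "VALID":
        pad_lo = pad_hi = 0
    else:
        raise ValueError(f"unsupported padding: {padding!r}")

    xp = np.pad(x, ((0, 0), (0, 0), (pad_lo, pad_hi)))
    out_len = max((xp.shape[-1] - eff_k) // stride + 1, 0)
    y = np.zeros((n, c_out, out_len))
    for t in range(out_len):
        start = t * stride
        # Input samples under the dilated kernel taps: shape (n, c_in, k).
        window = xp[:, :, start:start + eff_k:rhs_dilation]
        # Cross-correlation: contract input channels and taps (no kernel flip).
        y[:, :, t] = np.einsum("nik,oik->no", window, kernel)
    if bias is not None:
        # Assumed: one bias value per output channel, broadcast over batch and length.
        y = y + np.asarray(bias, dtype=float)[None, :, None]
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3, 16))
kernel = rng.standard_normal((4, 3, 5))
print(reference_conv1d(x, kernel).shape)  # (8, 4, 16)
```

The explicit loop favours readability over speed; it is only meant as a small oracle for checking shapes and values on toy inputs.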

Parameters:
  • x (ArrayLike) – Input tensor with a leading batch dimension.

  • kernel (ArrayLike) – Convolution kernel, with layout governed by dimension_numbers.

  • bias (ArrayLike or None, optional) – Per-output-channel bias. Default None.

  • strides (Sequence[int]) – Window strides, one entry per spatial dimension. Default (1,).

  • padding (str) – Padding mode (e.g. 'SAME' or 'VALID'). Default 'SAME'.

  • lhs_dilation (Optional[Sequence[int]]) – Left-hand-side (input) dilation factors. Default None (no dilation).

  • rhs_dilation (Optional[Sequence[int]]) – Right-hand-side (kernel) dilation factors. Default None (no dilation).

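  • feature_group_count (int) – Number of feature groups for grouped convolution. The input channel count must equal the kernel's input-feature size times this value, and the output channel count must be divisible by it. Default 1.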

  • batch_group_count (int) – Number of batch groups. Default 1.

  • dimension_numbers (Any) – Convolution dimension numbers (e.g. ('NHWC', 'HWIO', 'NHWC')). Default None, which uses the JAX default channels-first layout: ('NCH', 'OIH', 'NCH') for 1-D and ('NCHW', 'OIHW', 'NCHW') for 2-D convolution.
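The parameters above jointly determine the output shape. The sketch below computes it for the default channels-first layout, assuming dimension_numbers=None, lhs_dilation=None, batch_group_count=1 and string padding; the helper names conv_output_length and expected_output_shape are hypothetical and not part of braintrace. The 'SAME' and 'VALID' rules and the grouping checks follow the documented behaviour of jax.lax.conv_general_dilated.

```python
import math


def conv_output_length(length, kernel_size, stride=1, padding="SAME", dilation=1):
    """Hypothetical helper: output length along one spatial axis."""
    eff_k = (kernel_size - 1) * dilation + 1  # kernel extent after rhs_dilation
    if padding == "SAME":
        return math.ceil(length / stride)
    if padding == "VALID":
        return max((length - eff_k) // stride + 1, 0)
    raise ValueError(f"unsupported padding: {padding!r}")


def expected_output_shape(x_shape, kernel_shape, strides=(1,), padding="SAME",
                          rhs_dilation=None, feature_group_count=1):
    """Hypothetical helper: output shape for NC<spatial> input / OI<spatial> kernel.

    Assumes dimension_numbers=None, lhs_dilation=None and batch_group_count=1.
    """
    n, c_in, *spatial = x_shape
    c_out, k_in, *k_spatial = kernel_shape
    if len(strides) != len(spatial) or len(k_spatial) != len(spatial):
        raise ValueError("strides and kernel need one entry per spatial dimension")
    # Grouped convolution: each of the G groups maps c_in // G inputs to c_out // G outputs.
    if c_in != k_in * feature_group_count:
        raise ValueError("input channels must equal kernel in-features * feature_group_count")
    if c_out % feature_group_count != 0:
        raise ValueError("output channels must be divisible by feature_group_count")
    dilation = rhs_dilation if rhs_dilation is not None else (1,) * len(spatial)
    out_spatial = tuple(
        conv_output_length(size, k, s, padding, d)
        for size, k, s, d in zip(spatial, k_spatial, strides, dilation)
    )
    return (n, c_out, *out_spatial)


print(expected_output_shape((8, 3, 16), (4, 3, 5)))  # (8, 4, 16)
```

Note that 'SAME' padding makes the output length depend only on the input length and stride, whereas 'VALID' shrinks it by the dilated kernel extent.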

Returns:

Convolution output tensor.

Return type:

ArrayLike

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> brainstate.environ.set(precision=64)
>>> # 1-D conv, NCH input and OIH kernel (JAX defaults)
>>> x = brainstate.random.randn(8, 3, 16)
>>> kernel = brainstate.random.randn(4, 3, 5)
>>> y = braintrace.conv(x, kernel, strides=(1,), padding='SAME')
>>> print(y.shape)
(8, 4, 16)
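The example above uses JAX's default channels-first layout. Channels-last data needs both transposed arrays and a matching dimension_numbers tuple. The sketch below builds them with NumPy; the helper name to_channels_last and the choice of spatial letters "H", "W", "D" are assumptions for illustration, not part of braintrace.

```python
import numpy as np


def to_channels_last(x, kernel):
    """Hypothetical helper: reorder NC<spatial> / OI<spatial> arrays to channels-last.

    Returns the transposed input, the transposed kernel, and a matching
    dimension_numbers tuple such as ('NHC', 'HIO', 'NHC') for one spatial axis.
    """
    x = np.asarray(x)
    kernel = np.asarray(kernel)
    nd = x.ndim - 2  # number of spatial dimensions
    if kernel.ndim != x.ndim or not 1 <= nd <= 3:
        raise ValueError("expected matching ranks with 1 to 3 spatial dimensions")
    spatial = "HWD"[:nd]
    x_last = np.transpose(x, (0, *range(2, 2 + nd), 1))           # N, spatial..., C
    kernel_last = np.transpose(kernel, (*range(2, 2 + nd), 1, 0))  # spatial..., I, O
    dims = ("N" + spatial + "C", spatial + "IO", "N" + spatial + "C")
    return x_last, kernel_last, dims


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3, 16))
kernel = rng.standard_normal((4, 3, 5))
x_nhc, kernel_hio, dims = to_channels_last(x, kernel)
print(x_nhc.shape, kernel_hio.shape, dims)
# (8, 16, 3) (5, 3, 4) ('NHC', 'HIO', 'NHC')
```

Assuming dimension_numbers is forwarded to jax.lax.conv_general_dilated as the description states, braintrace.conv(x_nhc, kernel_hio, padding='SAME', dimension_numbers=dims) should return shape (8, 16, 4), i.e. the channels-first result with its last two axes swapped.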