conv
- class braintrace.conv(x, kernel, bias=None, *, strides=(1,), padding='SAME', lhs_dilation=None, rhs_dilation=None, feature_group_count=1, batch_group_count=1, dimension_numbers=None)
ETP-aware convolution.
Computes \(y = \mathrm{conv}(x, \mathrm{kernel}) + b\), where the \(+\, b\) term applies only when bias is given. The kernel (and the optional bias) is routed through an ETP primitive so that it participates in eligibility-trace computation. The full keyword surface of jax.lax.conv_general_dilated() is preserved. x must always include a leading batch dimension.
- Parameters:
x (ArrayLike) – Input tensor with a leading batch dimension.
kernel (ArrayLike) – Convolution kernel, with its layout governed by dimension_numbers.
bias (ArrayLike or None, optional) – Per-output-channel bias. Default None.
strides (Sequence[int]) – Window strides, one per spatial dimension. Default (1,).
padding (str) – Padding mode (e.g. 'SAME' or 'VALID'). Default 'SAME'.
lhs_dilation (Optional[Sequence[int]]) – Left-hand-side (input) dilation factors. Default None.
rhs_dilation (Optional[Sequence[int]]) – Right-hand-side (kernel) dilation factors. Default None.
feature_group_count (int) – Number of feature groups. Default 1.
batch_group_count (int) – Number of batch groups. Default 1.
dimension_numbers (Any) – Convolution dimension numbers (e.g. ('NHWC', 'HWIO', 'NHWC')). Default None, which uses the JAX default channels-first layout (e.g. NCH/OIH/NCH for 1-D, NCHW/OIHW/NCHW for 2-D).
- Returns:
Convolution output tensor.
- Return type:
ArrayLike
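To make the formula above concrete, here is a minimal plain-JAX sketch of the forward value only. It calls jax.lax.conv_general_dilated directly and adds the bias along the output-channel axis. The helper name reference_conv is hypothetical and not part of braintrace. The sketch assumes the default channels-first layout (dimension_numbers=None) and deliberately omits the ETP routing, so it produces no eligibility traces.

```python
import jax
import jax.numpy as jnp


def reference_conv(x, kernel, bias=None, strides=(1,), padding='SAME'):
    """Hypothetical helper: forward value of conv(x, kernel) (+ b) in plain JAX.

    No ETP primitive is involved, so no eligibility traces are produced.
    Assumes dimension_numbers=None (channels-first: NCH input, OIH kernel,
    NCH output for 1-D), so the output-channel axis is axis 1.
    """
    y = jax.lax.conv_general_dilated(x, kernel, window_strides=strides, padding=padding)
    if bias is not None:
        # Reshape the (O,) bias to (1, O, 1, ..., 1) so it broadcasts over
        # the batch axis and every spatial axis.
        y = y + jnp.reshape(bias, (1, -1) + (1,) * (y.ndim - 2))
    return y


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 3, 16))       # (batch, in_channels, length)
kernel = jax.random.normal(key_k, (4, 3, 5))   # (out_channels, in_channels, width)
bias = jnp.arange(4.0)                         # one value per output channel

y_no_bias = reference_conv(x, kernel)
y = reference_conv(x, kernel, bias)
print(y.shape)  # (8, 4, 16)
```

The shapes match the doctest that follows, because 'SAME' padding with stride 1 keeps the spatial length at 16.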
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> brainstate.environ.set(precision=64)
>>> # 1-D conv, NCH input and OIH kernel (JAX defaults)
>>> x = brainstate.random.randn(8, 3, 16)
>>> kernel = brainstate.random.randn(4, 3, 5)
>>> y = braintrace.conv(x, kernel, strides=(1,), padding='SAME')
>>> print(y.shape)
(8, 4, 16)