Poster in Workshop: UniReps: Unifying Representations in Neural Models

Multi-task Learning yields Disentangled World Models: Impact and Implications

Pantelis Vafidis · Aman Bhargava · Antonio Rangel

Keywords: [ transformers ] [ world models ] [ representation learning ] [ computational neuroscience ] [ interpretability ] [ multi-task learning ] [ disentanglement ] [ continuous attractors ] [ feature-based generalization ] [ RNNs ] [ zero-shot generalization ]


Abstract:

Intelligent perception and interaction with the world hinge on internal representations that capture its underlying geometry ("disentangled" or "abstract" representations). The ability to form these disentangled representations from high-dimensional, noisy observations is a hallmark of intelligence, observed in both biological and artificial systems. In this opinion paper we highlight unpublished experimental and theoretical results guaranteeing the emergence of disentangled representations in agents that optimally solve multi-task evidence-aggregation classification tasks, canonical in the cognitive neuroscience literature. The key conceptual finding is that, by producing accurate multi-task classification estimates, a system implicitly represents a set of coordinates specifying a disentangled, topology-preserving representation of the underlying latent space. Since the theory relies only on the system accurately computing the classification probabilities, we derive a closed-form solution for extracting disentangled representations from any multi-task classification system. The theory provides conditions for the emergence of these representations in terms of noise, number of tasks, and evidence-aggregation time, and we experimentally validate its predictions on RNNs and GPT-2 transformers solving such canonical evidence-aggregation decision-making tasks from neuroscience. We find that transformers are particularly well suited to disentangling representations, which may explain their unique world-understanding abilities. Overall, our opinion paper puts forth parallel processing as a general principle for the formation of cognitive maps that capture the structure of the world and are shared across biological and artificial systems; it helps explain why ANNs often arrive at human-interpretable concepts, and how both kinds of systems may acquire exceptional zero-shot generalization capabilities. We discuss the implications of these findings for machine learning and neuroscience alike.
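
The abstract states that latent coordinates can be read out in closed form from accurate multi-task classification probabilities, but does not spell out the construction. The sketch below is a minimal illustration of the general idea, not the authors' actual derivation: it assumes each task is a linear boundary through the origin of a low-dimensional latent space with Gaussian-style evidence noise, so the logit of each task probability is proportional to the signed margin, and stacking the logits gives a linear system whose least-squares solution recovers the latent coordinates. The names (W, noise, z_true) and the specific noise model are illustrative assumptions.

```python
# Hypothetical sketch of extracting latent coordinates from multi-task
# classification probabilities. Assumptions (not taken from the paper): each
# task i is a linear boundary with unit normal w_i through the origin, and an
# ideal classifier reports p_i = sigmoid((w_i . z) / noise) for latent z.
import numpy as np

rng = np.random.default_rng(0)

d = 2            # latent dimensionality
n_tasks = 5      # number of classification tasks (>= d so the system is solvable)
noise = 0.5      # assumed evidence-noise scale

# Random task boundaries: rows of W are unit normal vectors.
W = rng.normal(size=(n_tasks, d))
W /= np.linalg.norm(W, axis=1, keepdims=True)

# A ground-truth latent point and the probabilities an ideal multi-task
# classifier would report for it under the assumed noise model.
z_true = rng.normal(size=d)
p = 1.0 / (1.0 + np.exp(-(W @ z_true) / noise))

# Closed-form readout under these assumptions: the logit undoes the sigmoid,
# recovering the scaled margins w_i . z; least squares then inverts the
# (overdetermined) linear map to recover the latent coordinates.
margins = noise * np.log(p / (1.0 - p))
z_hat, *_ = np.linalg.lstsq(W, margins, rcond=None)

print("true latent:     ", np.round(z_true, 3))
print("recovered latent:", np.round(z_hat, 3))   # agrees up to numerical error
```

With noisy, finite-time probability estimates instead of exact ones, the same readout would only approximate the latent, which is where conditions on noise, number of tasks, and evidence-aggregation time of the kind mentioned in the abstract would come into play.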
