A Causality Summary Part III

Main Reference: https://arxiv.org/abs/2405.08793



Learning in Latent Variable Models


Latent variable models are powerful extensions of probabilistic models. They introduce hidden or unobserved variables to explain dependencies in data that cannot be captured solely by observed variables. These latent variables often represent underlying processes or structures that are not directly measurable.


Latent Variable Representation

Given a pair of variables (v, v'), the latent variable model introduces an additional unobserved variable u to model the relationship between v and v'. The conditional probability of v given v' is then obtained by marginalizing out u:

    \[p(v \mid v'; \theta) = \sum_u p_u(u) p(v \mid v', u; \theta)\]

where:

  • u is the latent (unobserved) variable.
  • p_u(u) is the prior distribution over the latent variable.
  • p(v \mid v', u; \theta) represents the conditional probability of v given v', u, and parameters \theta.

For continuous latent variables, the summation over u is replaced by an integral:

    \[p(v \mid v'; \theta) = \int p_u(u) p(v \mid v', u; \theta) \, du\]
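A minimal sketch of the discrete case, assuming u takes K values with a prior p_u(u) given as a list, and Gaussian conditionals p(v \mid v', u; \theta) = \mathcal{N}(v; w_u v' + b_u, 1). The per-component pairs (w_u, b_u) stand in for \theta; this parameterization is a hypothetical choice for illustration, not taken from the reference. The sum over u is evaluated in log space with log-sum-exp for numerical stability.

```python
import math

def log_normal(x, mean, var):
    """Log-density of a univariate Gaussian N(x; mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(x - m) for x in values))

def log_marginal(v, v_prime, theta, prior):
    """log p(v | v'; theta) = log sum_u p_u(u) p(v | v', u; theta).

    theta is a list of (w_u, b_u) pairs, one per latent value u
    (hypothetical parameterization); prior[u] is p_u(u).
    """
    terms = [
        math.log(prior[u]) + log_normal(v, w * v_prime + b, 1.0)
        for u, (w, b) in enumerate(theta)
    ]
    return logsumexp(terms)

# Two latent "regimes": v tracks v' in one and -v' + 3 in the other.
theta = [(1.0, 0.0), (-1.0, 3.0)]
prior = [0.5, 0.5]
# At v' = 1.5 both regimes predict 1.5, so this is log N(0; 0, 1) ≈ -0.919.
print(log_marginal(1.5, 1.5, theta, prior))
```

With a single latent value and prior [1.0], the mixture collapses to an ordinary Gaussian conditional; the extra regimes are what let the model represent multimodal p(v \mid v').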

Understanding \theta: Why it Matters

The parameter \theta is central to the definition and learning of latent variable models. It represents the set of model parameters that control the model’s behavior. Specifically:

  1. Parameterizing Conditional Distributions:
    • The distribution p(v \mid v', u; \theta) depends on \theta to define the relationship between the variables.
    • \theta can represent the weights in a neural network, the coefficients in a linear model, or any other model-specific parameters.
  2. Learning from Data:
    • The process of learning involves estimating \theta from a dataset \mathcal{D} = \{(v_1, v_1'), \dots, (v_N, v_N')\} to maximize the likelihood of the observed data.
    • \theta acts as the tunable component of the model, adapting to the underlying data-generating process.
  3. Generalization:
    • A well-optimized \theta enables the model to generalize to unseen data. Poorly chosen parameters can result in overfitting or underfitting.
  4. Flexibility:
    • By incorporating \theta, the model gains the flexibility to represent a wide range of distributions, accommodating different patterns in the data.

Learning and a Generative Process

Learning in Machine Learning

Learning is defined as the process of configuring a predictive model to capture the relationship between variables. Specifically:

  • Input-output relationship:
    • Given an input v' and an output v, the goal is to train a predictive function g_\theta(v') that models the relationship between them. This function outputs the prediction of v given v'.
  • Parameters \theta:
    • These are the tunable components of g_\theta, adjusted during the training process to make g_\theta as predictive as possible.

Dataset and Learning Process

  • The dataset \mathcal{D} consists of pairs (v_1, v_1'), (v_2, v_2'), \dots, (v_N, v_N'), which are examples drawn from an underlying data-generating process G. Learning aims to infer the optimal parameters \theta so that g_\theta(v') best predicts v across the dataset.

To evaluate how well g_\theta(v') captures the relationship between v and v', the log-probability of the data under the model is calculated:

    \[r(\theta; v, v') = \log p(v \mid v'; g_\theta(v'))\]

This log probability measures the likelihood of the observed v, given v' and the predictive function g_\theta. A higher value of r(\theta; v, v') indicates that the model aligns well with the observed data.

The learning process seeks to maximize this log probability across the entire dataset. The optimization objective is:

    \[\arg \max_\theta \frac{1}{N} \sum_{n=1}^N r(\theta; v_n, v_n')\]

where N is the total number of examples. By maximizing this objective, the model learns the parameters \theta that make the predictions most consistent with the data. Once learning is complete, the model provides an approximation of the true conditional probability:

    \[p(v \mid v') \approx p(v \mid v'; g_\theta(v')) = p(v \mid v'; \theta)\]

This approximation is critical, as the true conditional distribution may not be directly accessible or computationally feasible to derive. By learning g_\theta(v'), we achieve a practical representation of p(v \mid v').
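To make the objective concrete, here is a small sketch under two illustrative assumptions: g_\theta(v') = w v' + b is linear, and p(v \mid v'; g_\theta(v')) = \mathcal{N}(v; g_\theta(v'), 1). It maximizes the average of r(\theta; v_n, v_n') by plain gradient ascent on synthetic data from a made-up process G, so we can check that learning recovers it.

```python
import math
import random

def r(theta, v, v_prime):
    """Per-example log-probability r(theta; v, v') = log N(v; g_theta(v'), 1)."""
    w, b = theta
    mean = w * v_prime + b  # g_theta(v'): a hypothetical linear predictor
    return -0.5 * math.log(2 * math.pi) - 0.5 * (v - mean) ** 2

def fit(data, lr=0.1, steps=500):
    """Gradient ascent on (1/N) sum_n r(theta; v_n, v'_n)."""
    w, b = 0.0, 0.0
    n = len(data)
    for _ in range(steps):
        # dr/dw = (v - g) * v'  and  dr/db = (v - g), averaged over the dataset
        grad_w = sum((v - (w * vp + b)) * vp for v, vp in data) / n
        grad_b = sum((v - (w * vp + b)) for v, vp in data) / n
        w += lr * grad_w
        b += lr * grad_b
    return w, b

# Synthetic dataset D from an assumed process G: v = 2 v' - 1 + N(0, 1) noise.
rng = random.Random(0)
data = []
for _ in range(500):
    vp = rng.gauss(0.0, 1.0)
    data.append((2.0 * vp - 1.0 + rng.gauss(0.0, 1.0), vp))

theta = fit(data)
avg = sum(r(theta, v, vp) for v, vp in data) / len(data)
print(theta, avg)
```

For this Gaussian choice, maximizing the average log-probability is equivalent to least-squares regression; other choices of p(v \mid v'; g_\theta(v')) lead to other familiar losses.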


Generalizing from Probabilistic Graphical Models

Probabilistic graphical models provide a structured framework to represent dependencies among variables using a graph G = (V, E, P), where:

  • V is the set of variables.
  • E represents directed edges that encode the conditional dependencies between variables.
  • P is the joint probability distribution over all variables in V.

In this context, the goal is to compute the conditional probability p(v \mid v') for two variables v, v' \in V. This can be achieved by marginalizing over all other variables \bar{V} = V \setminus \{v, v'\} in the graph.


Marginalization and Conditioning

The conditional probability p(v \mid v') can be expressed as:

    \[p(v \mid v') = \frac{\sum_{\bar{V}} p(\{v, v'\} \cup \bar{V})}{\sum_{\{v\} \cup \bar{V}} p(\{v, v'\} \cup \bar{V})}\]

where:

  • \bar{V}: The set of all variables in V except v and v'. These variables are marginalized out to focus only on the relationship between v and v'.
  • p(\{v, v'\} \cup \bar{V}): The joint probability over v, v', and the set of remaining variables \bar{V}.
  • \sum_{\{v\} \cup \bar{V}} p(\{v, v'\} \cup \bar{V}): The same joint probability summed over both v and \bar{V}, which equals the marginal probability p(v'). This is used to normalize the conditional distribution.

Key insights from the marginalization equation:

  1. Marginalization Process:
    • The numerator, \sum_{\bar{V}} p(\{v, v'\} \cup \bar{V}), sums out all variables in \bar{V}, leaving the joint distribution p(v, v'). The other variables still shape p(v \mid v'); their influence is averaged over rather than removed.
  2. Normalization:
    • The denominator, \sum_{\{v\} \cup \bar{V}} p(\{v, v'\} \cup \bar{V}) = p(v'), ensures that p(v \mid v') is a valid probability distribution, summing to 1 over all possible values of v.
  3. Dependence on the Graph Structure: 
    • The computation of p(v \mid v') depends on the full graph structure G. Even if only v and v' are of interest, the marginalization process involves all other variables in \bar{V} due to their potential influence on v and v'.
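A small worked example of this computation, using a made-up joint table over three binary variables in which a single variable z stands in for \bar{V}. The numerator sums z out of the joint; the denominator additionally sums over v, which yields p(v'). The table values are hypothetical.

```python
from itertools import product

# Hypothetical joint distribution p(v, v', z) over three binary variables;
# z plays the role of the remaining variables \bar{V}. Entries sum to 1.
joint = {
    (0, 0, 0): 0.10, (0, 0, 1): 0.05,
    (0, 1, 0): 0.05, (0, 1, 1): 0.10,
    (1, 0, 0): 0.20, (1, 0, 1): 0.10,
    (1, 1, 0): 0.15, (1, 1, 1): 0.25,
}

def conditional(joint, v, v_prime):
    """p(v | v') = sum_z p(v, v', z) / sum_{v, z} p(v, v', z)."""
    numerator = sum(joint[(v, v_prime, z)] for z in (0, 1))
    # Summing over both v and z gives the marginal p(v').
    denominator = sum(joint[(vv, v_prime, z)] for vv, z in product((0, 1), repeat=2))
    return numerator / denominator

for vp in (0, 1):
    print(vp, [conditional(joint, v, vp) for v in (0, 1)])
```

For instance, p(v = 1 \mid v' = 0) = (0.20 + 0.10) / 0.45 = 2/3. In a real graphical model the joint would be assembled from the graph's factors rather than stored as a table, but the sums are the same.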


Causal Models and Structural Ambiguities


Learning can recover conditional distributions induced by the data-generating process G. However, the same conditional distribution may result from multiple causal models, introducing ambiguities. Consider two structural causal models:


Causal Model 1:

    \[v' \leftarrow \epsilon_{v'}, \quad v^l \leftarrow v' + a + \epsilon_{v^l}, \quad v^r \leftarrow v' + b + \epsilon_{v^r}, \quad v \leftarrow v^l + v^r + \epsilon_v\]

In this model:

  • v' is influenced solely by noise \epsilon_{v'}.
  • v^l and v^r are intermediate variables influenced by v' and respective noise terms \epsilon_{v^l} and \epsilon_{v^r}.
  • v depends on v^l, v^r, and an additional noise term \epsilon_v.


Causal Model 2:

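    \[v' \leftarrow \epsilon_{v'}, \quad v^c \leftarrow 2v' + a + b + \epsilon_{v^c}, \quad v \leftarrow v^c + \epsilon_v\]

Here:

  • v' is again influenced solely by noise \epsilon_{v'}.
  • v^c directly depends on v' (with coefficient 2), the constants a and b, and noise \epsilon_{v^c}.
  • v is derived from v^c and noise \epsilon_v.


Ambiguities in Conditional Probabilities

Assume \epsilon_{v'}, \epsilon_{v^l}, \epsilon_{v^r}, \epsilon_v \sim \mathcal{N}(0, 1) and \epsilon_{v^c} \sim \mathcal{N}(0, 2). Substituting the intermediate variables, Model 1 gives v = 2v' + a + b + \epsilon_{v^l} + \epsilon_{v^r} + \epsilon_v, with three independent unit-variance noise terms, and Model 2 gives v = 2v' + a + b + \epsilon_{v^c} + \epsilon_v, with total noise variance 2 + 1 = 3. Despite their structural differences, both models therefore produce the same conditional distribution:

    \[p(v \mid v') = \mathcal{N}(v; 2v' + a + b, 3)\]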

This illustrates a fundamental ambiguity: the same conditional probability can emerge from distinct causal structures. Consequently, conditional distributions alone cannot uniquely identify the true causal model.
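A quick simulation illustrates the ambiguity. The sketch assumes standard-normal noise throughout Model 1 and, for Model 2, the single-path form v^c \leftarrow 2v' + a + b + \epsilon_{v^c} with \epsilon_{v^c} \sim \mathcal{N}(0, 2), which is what Model 1 collapses to (mean 2v' + a + b, variance 3) once v^l and v^r are substituted. The values of a, b and the fixed v' are arbitrary. Because v' is a root node driven only by its own noise, fixing v' and sampling everything downstream draws from p(v \mid v').

```python
import random

def model_1(vp, a, b, rng):
    """Model 1: two intermediate paths v^l and v^r, unit-variance noise."""
    v_l = vp + a + rng.gauss(0.0, 1.0)
    v_r = vp + b + rng.gauss(0.0, 1.0)
    return v_l + v_r + rng.gauss(0.0, 1.0)

def model_2(vp, a, b, rng):
    """Model 2 (assumed form): one path v^c = 2 v' + a + b + N(0, 2) noise."""
    v_c = 2.0 * vp + a + b + rng.gauss(0.0, 2.0 ** 0.5)  # gauss takes a std. dev.
    return v_c + rng.gauss(0.0, 1.0)

def conditional_moments(model, vp, a, b, n=100_000, seed=0):
    """Monte Carlo estimate of the mean and variance of p(v | v') at a fixed v'."""
    rng = random.Random(seed)
    samples = [model(vp, a, b, rng) for _ in range(n)]
    mean = sum(samples) / n
    var = sum((s - mean) ** 2 for s in samples) / (n - 1)
    return mean, var

a, b, vp = 0.5, -2.0, 1.0
for model in (model_1, model_2):
    # Both estimates should be close to mean 2*vp + a + b = 0.5 and variance 3.
    print(model.__name__, conditional_moments(model, vp, a, b))
```

No amount of (v', v) data distinguishes these two models; telling them apart requires intervening on, or observing, the intermediate variables.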


Implications for Learning and Generalization

  1. Learning from Conditional Distributions:
    • The learning process enables the recovery of conditional distributions from data. However, without additional assumptions, it cannot disambiguate between different causal mechanisms producing the same distribution.
  2. Out-of-Distribution Generalization:
    • Understanding the underlying causal structure is essential for robust generalization to unseen data distributions. Ambiguities in causal models highlight the challenges of making predictions in new environments.
  3. The Role of Structural Causal Models:
    • Structural causal models explicitly represent the mechanisms generating the data, offering insights into the relationships between variables and enabling counterfactual reasoning.

In summary, while learning provides practical tools for approximating conditional distributions, interpreting these distributions in the context of causal inference requires careful consideration of structural assumptions and potential ambiguities.



Learning Objectives and Latent Variable Models


Latent variable models extend probabilistic models by introducing unobserved variables to capture underlying dependencies in the data that are not directly observable.

These models are particularly useful for representing complex relationships when the full data-generating process is unknown or partially observed. This section elaborates on the learning objectives, gradient approximation, and the role of latent variable models in predictive and generative tasks.


1. Learning in Latent Variable Models

Learning in latent variable models revolves around estimating the parameters \theta by maximizing the likelihood of the observed data. The likelihood is expressed as:

    \[p(v \mid v'; \theta) = \sum_u p_u(u) p(v \mid v', u; \theta)\]

where:

  • u is the latent (unobserved) variable.
  • p_u(u) is the prior distribution of the latent variable u.
  • p(v \mid v', u; \theta) is the conditional distribution of v given v' and u, parameterized by \theta.

Challenges:

  1. Marginalization Complexity:
    • Computing the sum \sum_u over all possible values of u is often intractable, especially when u is high-dimensional or continuous.
  2. Gradient-Based Optimization:
    • Maximizing the likelihood requires calculating gradients of the log-likelihood with respect to \theta, which involves terms dependent on u.

2. Sampling-Based Gradient Approximation


To address the computational intractability of marginalization, sampling-based approximations are used to estimate the gradient of the log-likelihood:


Gradient Expression

The gradient of the log-likelihood is given by:

    \[\nabla \log p(v \mid v'; \theta) = \sum_u \frac{p_u(u)p(v \mid v', u; \theta)}{\sum_{u'} p_u(u')p(v \mid v', u'; \theta)} \nabla \log p(v \mid v', u; \theta)\]

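Here:

  • The numerator p_u(u) p(v \mid v', u; \theta) is the joint probability of the latent configuration u and the observed v, given v'.
  • The denominator normalizes these contributions across all configurations of u, so the ratio is exactly the posterior p(u \mid v, v'; \theta). The gradient is therefore the posterior expectation of \nabla \log p(v \mid v', u; \theta).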
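The expression can be checked numerically on a toy model. As a sketch, assume u takes three discrete values with a fixed prior and p(v \mid v', u; \theta) = \mathcal{N}(v; w_u v' + b_u, 1), a hypothetical parameterization. The weights in the formula are then the posterior probabilities of u, and the posterior-weighted gradient should agree with a finite-difference derivative of \log p(v \mid v'; \theta).

```python
import math

def log_normal(x, mean, var=1.0):
    """Log-density of a univariate Gaussian N(x; mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def log_joint_terms(v, vp, theta, prior):
    """log[p_u(u) p(v | v', u; theta)] for every latent value u."""
    return [math.log(prior[u]) + log_normal(v, w * vp + b) for u, (w, b) in enumerate(theta)]

def log_marginal(v, vp, theta, prior):
    """log p(v | v'; theta) via log-sum-exp over u."""
    terms = log_joint_terms(v, vp, theta, prior)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

def posterior(v, vp, theta, prior):
    """Weights p_u(u) p(v|v',u) / sum_u' p_u(u') p(v|v',u')."""
    terms = log_joint_terms(v, vp, theta, prior)
    m = max(terms)
    unnorm = [math.exp(t - m) for t in terms]
    z = sum(unnorm)
    return [x / z for x in unnorm]

def exact_grad(v, vp, theta, prior):
    """Posterior-weighted sum of the gradients of log p(v | v', u; theta).

    For Gaussian component u, d/dw_u log p = (v - mean_u) v' and
    d/db_u log p = (v - mean_u); other components contribute zero.
    """
    post = posterior(v, vp, theta, prior)
    grads = []
    for u, (w, b) in enumerate(theta):
        resid = v - (w * vp + b)
        grads.append((post[u] * resid * vp, post[u] * resid))
    return grads

theta = [(1.0, 0.0), (-0.5, 2.0), (0.3, -1.0)]
prior = [0.2, 0.5, 0.3]
print(exact_grad(0.7, 1.2, theta, prior))
```

This exact sum is affordable here only because u has three values; the next step replaces it with samples.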


Simplification via Posterior Sampling:

Using posterior samples u^m, the gradient is approximated as:

    \[\nabla \log p(v \mid v'; \theta) \approx \frac{1}{M} \sum_{m=1}^M \nabla \log p(v \mid v', u^m; \theta)\]

where u^1, \dots, u^M are samples drawn from the posterior p(u \mid v, v'; \theta).

  • Intuition:
    • Instead of summing over all possible values of u, the gradient is estimated using a finite number of samples from the posterior.
  • Efficiency:
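    • The cost scales with the number of samples M rather than with the number of configurations of u, which makes the gradient computation feasible for high-dimensional u, provided posterior samples can be drawn efficiently.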
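The sketch below applies this estimator to a toy model with three discrete latent values and hypothetical Gaussian components \mathcal{N}(v; w_u v' + b_u, 1). The posterior is small enough to compute exactly here, so u^m can be drawn from it directly with random.choices and the Monte Carlo average compared with the exact posterior-weighted gradient. In realistic models exact posterior sampling is itself hard.

```python
import math
import random

def log_normal(x, mean):
    """Log-density of N(x; mean, 1)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mean) ** 2

def posterior(v, vp, theta, prior):
    """Exact posterior p(u | v, v'; theta) for a small discrete latent u."""
    terms = [math.log(prior[u]) + log_normal(v, w * vp + b) for u, (w, b) in enumerate(theta)]
    m = max(terms)
    unnorm = [math.exp(t - m) for t in terms]
    z = sum(unnorm)
    return [x / z for x in unnorm]

def mc_grad(v, vp, theta, prior, num_samples, rng):
    """(1/M) sum_m grad log p(v | v', u^m; theta), with u^m ~ p(u | v, v'; theta)."""
    post = posterior(v, vp, theta, prior)
    total = [[0.0, 0.0] for _ in theta]
    for u in rng.choices(range(len(theta)), weights=post, k=num_samples):
        w, b = theta[u]
        resid = v - (w * vp + b)
        # The gradient of log p(v | v', u; theta) is nonzero only for component u.
        total[u][0] += resid * vp
        total[u][1] += resid
    return [[t / num_samples for t in pair] for pair in total]

def exact_grad(v, vp, theta, prior):
    """Exact posterior-weighted gradient, for comparison."""
    post = posterior(v, vp, theta, prior)
    return [[p * (v - (w * vp + b)) * vp, p * (v - (w * vp + b))]
            for p, (w, b) in zip(post, theta)]

theta = [(1.0, 0.0), (-0.5, 2.0), (0.3, -1.0)]
prior = [0.2, 0.5, 0.3]
approx = mc_grad(0.7, 1.2, theta, prior, num_samples=20_000, rng=random.Random(0))
exact = exact_grad(0.7, 1.2, theta, prior)
print(approx)
print(exact)
```

The estimate is unbiased, and its error shrinks roughly as 1/\sqrt{M}.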

3. Modeling with Latent Variables

Latent variable models capture unobserved factors contributing to the data distribution’s complexity. These models are widely used in tasks where direct observation of all relevant variables is impossible.

Autoregressive Models:

In an autoregressive framework, the joint conditional probability is factorized by the chain rule:

    \[p(v \mid v'; \theta) = \prod_{i=1}^d p(v_i \mid v_{<i}, v'; \theta)\]

where v = [v_1, \dots, v_d] is a vector of d variables and v_{<i} = [v_1, \dots, v_{i-1}] denotes the variables preceding v_i.

Key Idea: Each conditional probability p(v_i \mid v_{<i}, v'; \theta) is simpler to model individually, so the joint probability can be evaluated as a product of these terms and sampled one variable at a time.
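As a sketch, assume v is a vector of d binary variables and each factor is a logistic function of v' and the preceding coordinates v_{<i} (an illustrative parameterization, not from the reference). Evaluating \log p(v \mid v'; \theta) is a sum of per-step terms, sampling proceeds one coordinate at a time, and enumerating all 2^d outcomes confirms that the product of conditionals is a normalized distribution.

```python
import math
import random
from itertools import product

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def prob_one(i, prefix, vp, params):
    """p(v_i = 1 | v_{<i}, v'; theta) under a hypothetical logistic model."""
    alpha, betas, biases = params
    logit = alpha * vp + biases[i] + sum(betas[i][j] * prefix[j] for j in range(i))
    return sigmoid(logit)

def log_prob(v, vp, params):
    """log p(v | v'; theta) = sum_i log p(v_i | v_{<i}, v'; theta)."""
    total = 0.0
    for i, vi in enumerate(v):
        p1 = prob_one(i, v[:i], vp, params)
        total += math.log(p1 if vi == 1 else 1.0 - p1)
    return total

def sample(vp, params, d, rng):
    """Draw v one coordinate at a time, each conditioned on the ones before it."""
    v = []
    for i in range(d):
        v.append(1 if rng.random() < prob_one(i, v, vp, params) else 0)
    return v

d = 3
# (alpha, betas[i][j] for j < i, biases[i]): arbitrary illustrative values.
params = (0.8, [[], [1.5], [-0.7, 0.4]], [0.1, -0.3, 0.2])
total_mass = sum(math.exp(log_prob(list(v), 0.5, params)) for v in product((0, 1), repeat=d))
print(total_mass)  # ≈ 1.0: the factorized model is normalized
print(sample(0.5, params, d, random.Random(0)))
```

No sum over hidden configurations is needed: every quantity is a product of explicitly modelled conditionals.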


Latent Variable Models:

In contrast, latent variable models introduce an additional variable u to explain dependencies:

    \[p(v \mid v'; \theta) = \sum_u p_u(u)p(v \mid v', u; \theta)\]

  • Role of u:
    • The latent variable u accounts for unobserved influences, providing a more flexible representation of p(v \mid v').
  • Marginalization:
    • Summing over u integrates the effects of all possible configurations of the latent variable.


4. Practical Challenges and Solutions


Marginalization Challenges:

Marginalizing over u is computationally expensive, especially for continuous latent variables. In these cases, the summation \sum_u is replaced by an integral:

    \[p(v \mid v'; \theta) = \int p_u(u)p(v \mid v', u; \theta) \, du\]
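When the integral has no closed form, a simple unbiased estimate averages p(v \mid v', u^m; \theta) over samples u^m drawn from the prior p_u. The sketch below uses a linear-Gaussian toy model (u \sim \mathcal{N}(0, 1) and v \mid v', u \sim \mathcal{N}(v' + u, 1), both assumptions) because its marginal \mathcal{N}(v; v', 2) is known, so the estimate can be checked.

```python
import math
import random

def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(x; mean, var)."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def mc_marginal(v, vp, num_samples, rng):
    """Estimate p(v | v') = integral of p_u(u) p(v | v', u) du with prior samples."""
    total = 0.0
    for _ in range(num_samples):
        u = rng.gauss(0.0, 1.0)              # u^m ~ p_u = N(0, 1)
        total += normal_pdf(v, vp + u, 1.0)  # p(v | v', u^m) = N(v; v' + u^m, 1)
    return total / num_samples

v, vp = 1.3, 0.4
estimate = mc_marginal(v, vp, 100_000, random.Random(0))
exact = normal_pdf(v, vp, 2.0)  # analytic marginal N(v; v', 1 + 1)
print(estimate, exact)
```

This naive estimator works in one dimension but tends to have high variance when u is high-dimensional, because few prior samples explain a given v well.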


Sampling and Variational Inference:

Sampling-based methods approximate the log-likelihood gradient using posterior samples of u. However, posterior inference can still be challenging. A popular solution is variational inference, which approximates the intractable posterior with a tractable distribution q(u), often parameterized by a neural network that maps each observation to the parameters of q. Variational autoencoders (VAEs) use this approach: they learn the approximate posterior jointly with the generative model by maximizing a lower bound on the log-likelihood (the evidence lower bound, or ELBO).
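To show what variational inference optimizes, the sketch below evaluates the evidence lower bound (ELBO) in a conjugate toy model where everything has a closed form: u \sim \mathcal{N}(0, 1), v \mid v', u \sim \mathcal{N}(v' + u, 1), and a Gaussian q(u) = \mathcal{N}(m, s^2). These choices are assumptions for illustration; a VAE would instead predict m and s^2 with a network and estimate the expectations by sampling. The gap between \log p(v \mid v') and the ELBO is the KL divergence from q to the true posterior, so the bound is tight exactly when q equals that posterior.

```python
import math

def elbo(v, vp, m, s2):
    """ELBO for q(u) = N(m, s2) in the toy model u ~ N(0, 1), v | v', u ~ N(v' + u, 1).

    ELBO = E_q[log p(v | v', u)] + E_q[log p_u(u)] + H[q], each in closed form.
    """
    expected_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((v - vp - m) ** 2 + s2)
    expected_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s2)
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)
    return expected_lik + expected_prior + entropy

def log_marginal(v, vp):
    """Exact log p(v | v') = log N(v; v', 2) for the same toy model."""
    return -0.5 * math.log(2 * math.pi * 2.0) - (v - vp) ** 2 / 4.0

v, vp = 1.3, 0.4
exact_post = ((v - vp) / 2.0, 0.5)  # true posterior p(u | v, v') = N((v - v')/2, 1/2)
print(log_marginal(v, vp))
print(elbo(v, vp, *exact_post))  # matches log p(v | v'): the bound is tight
print(elbo(v, vp, 0.0, 1.0))     # q = prior gives a strictly lower bound
```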


5. Latent Variables and Their Relation to Causality

Latent variables in these models are not necessarily causal. While they capture unobserved dependencies, they may not correspond to actual variables in the data-generating process.

Key Observations:

  • Representation Flexibility:
    • Latent variables provide a mechanism to model complex distributions even when the true data-generating process is unknown or partially observed.
  • Causal Ambiguity:
    • The introduced variables u may not correspond to real-world causal factors. This distinction is critical when interpreting models in causal terms.

Practical Use:

Latent variable models are often used to improve the expressive power of neural networks. For example, they enable deep models to learn representations of high-dimensional data distributions, making them valuable for generative tasks.


Conclusion

Latent variable models extend traditional probabilistic models by introducing unobserved variables, enabling the modeling of complex relationships in data. While these models provide significant flexibility, their reliance on marginalization and sampling introduces computational challenges.

Methods like posterior sampling and variational inference address these issues, making latent variable models practical for modern machine-learning tasks. However, their use in causal inference requires careful consideration, as the introduced latent variables may not correspond to actual causal factors in the data-generating process.

The Pink Ipê tree, or “Ipê Rosa,” is celebrated for its stunning clusters of pink, trumpet-shaped flowers that bloom in late winter to early spring. This tree, native to Brazil, is a striking symbol of renewal, with its vibrant blossoms often appearing before its leaves.

Original image by @ota_cardoso