calcil package

Subpackages

Submodules

calcil.data_utils module

Implementation of data utilities for loading data from NumPy arrays and TensorFlow datasets.

calcil.data_utils.loader_from_numpy(input_dict, prefix_dim=None, random=True, seed=951, aux_terms=None, nojax=True)

Load data from numpy arrays.

Parameters:
  • input_dict (dict) – dictionary of numpy arrays. The first dimension is used as the batch dimension. All arrays should have the same length in the first dimension.

  • prefix_dim (tuple) – prefix dimension of the output arrays. In practice, it is used to control the batch size. If None, the first dimension is used.

  • random (bool) – whether to sample randomly. If False, the data is loaded in order.

  • seed (int) – random seed

  • aux_terms (dict) – auxiliary terms to be added to the output dictionary. The same terms are added to all batches.

  • nojax (bool) – if True, do not use JAX.

Returns:

a generator that yields a dictionary of numpy arrays.

Return type:

generator
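The batching behaviour described above can be illustrated with a small NumPy-only sketch. It is not calcil's implementation: the name sketch_loader is hypothetical, and it assumes that prefix_dim is a tuple whose product is the number of samples per batch (each yielded array then has shape prefix_dim + array.shape[1:]), that the generator cycles over the data indefinitely with a fresh permutation per pass, and that nojax can be ignored because only NumPy is used.

```python
import math

import numpy as np


def sketch_loader(input_dict, prefix_dim=None, random=True, seed=951, aux_terms=None):
    """Yield batches of input_dict along the first axis (illustrative sketch only)."""
    n = len(next(iter(input_dict.values())))
    if any(len(v) != n for v in input_dict.values()):
        raise ValueError("all arrays must have the same length in the first dimension")
    # Assumption: prefix_dim is the leading shape of every yielded array.
    prefix_dim = (n,) if prefix_dim is None else tuple(prefix_dim)
    per_batch = math.prod(prefix_dim)
    if per_batch > n:
        raise ValueError("prefix_dim asks for more samples than are available")
    rng = np.random.default_rng(seed)
    while True:  # assumption: cycle over the data indefinitely
        order = rng.permutation(n) if random else np.arange(n)
        # Samples that do not fill a complete batch are skipped in each pass.
        for start in range(0, n - per_batch + 1, per_batch):
            idx = order[start:start + per_batch]
            batch = {k: v[idx].reshape(prefix_dim + v.shape[1:]) for k, v in input_dict.items()}
            if aux_terms is not None:
                batch.update(aux_terms)  # same auxiliary terms added to every batch
            yield batch
```

With random=False the batches come out in order; with random=True each pass over the data is reshuffled with the seeded generator.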

calcil.data_utils.tfds_files_loader(tf_dataset)

calcil.dataclasses module

Utilities for defining custom classes that can be used with jax transformations.

Copyright 2022 The Flax Authors.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

calcil.dataclasses.dataclass(clz: _T) → _T

Create a class which can be passed to functional transformations.

Note: Inherit from PyTreeNode instead to avoid type-checking issues when using pytype.

JAX transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX. For example:

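import dataclasses
from types import FunctionType
from typing import Any
import jax
import jax.numpy as jnp
from flax import struct
@struct.dataclass
class Model:
    params: Any
    # use pytree_node=False to indicate an attribute should not be touched
    # by Jax transformations.
    apply_fn: FunctionType = struct.field(pytree_node=False)
    def __call__(self, *args):
        return self.apply_fn(self.params, *args)
def apply_fn(params, x):  # illustrative forward function
    return params["w"] * x
def some_loss_fn(model):  # illustrative scalar loss
    return jnp.sum(model(2.0) ** 2)
params, params_b = {"w": jnp.array(1.0)}, {"w": jnp.array(3.0)}
model = Model(params, apply_fn)
try:
    model.params = params_b  # Model is immutable. This raises an error.
except dataclasses.FrozenInstanceError:
    pass
model_b = model.replace(params=params_b)  # Use the replace method instead.
# This class can now be used safely in Jax to compute gradients w.r.t. the
# parameters.
model_grad = jax.grad(some_loss_fn)(model)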

Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. This correspondence is what makes these objects valid containers that work with JAX transformations and, more generally, with the jax.tree_util library. Sometimes a “smart constructor” is desired, for example because some of the attributes can be (optionally) derived from others. The way to do this with Flax dataclasses is to make a static or class method that provides the smart constructor. This way the simple constructor used by jax.tree_util is preserved. Consider the following example:

import jax
from flax import struct
Array = jax.Array
@struct.dataclass
class DirectionAndScaleKernel:
    direction: Array
    scale: Array
    @classmethod
    def create(cls, kernel):
        scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
        direction = kernel / scale
        return cls(direction, scale)

Parameters:

clz – the class that will be transformed by the decorator.

Returns:

The new class.
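The 1:1 correspondence between constructor arguments and fields noted above is what lets a tree utility take an instance apart and rebuild it. The following standard-library sketch illustrates the idea on an ordinary frozen dataclass; flatten and unflatten are hypothetical helpers, not the jax.tree_util or Flax implementation, and the "pytree_node" metadata key is used only to mimic the flag set by struct.field(pytree_node=False).

```python
import dataclasses


def flatten(obj):
    """Split a dataclass instance into pytree children and static aux data."""
    children, aux = [], []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        # Fields flagged pytree_node=False are treated as static, not as leaves.
        if f.metadata.get("pytree_node", True):
            children.append(value)
        else:
            aux.append((f.name, value))
    return children, aux


def unflatten(cls, children, aux):
    """Rebuild an instance by calling the auto-generated __init__."""
    names = [f.name for f in dataclasses.fields(cls) if f.metadata.get("pytree_node", True)]
    kwargs = dict(zip(names, children))
    kwargs.update(aux)  # aux is a list of (name, value) pairs
    return cls(**kwargs)
```

Because unflatten only calls the plain __init__, a smart constructor such as create() above never interferes with the round trip.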

calcil.dataclasses.get_dataclass_instances(serialize=False)
calcil.dataclasses.register_dataclass_instance(ins)
calcil.dataclasses.reset_dataclass_instances()

calcil.forward module

Implementation of the forward model class for image reconstruction.

class calcil.forward.Model(*args, **kwargs)

Bases: Module

forward(*input_args, variables=None, method_name=None, rngs=None)
log_intermediate()
model_hyperparams()

Serialize all model hyperparameters and return them (e.g., as JSON).

model_save()
name: str = None
parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
classmethod recover_from_hyperparams(s)
scope = None
var(name, mode='update', init_fn=None, shape=None, dtype=None)

Wrapper function for param and variable.

var_find(variables, s)

Find a variable by keyword matching and return its unique identifier.

var_verify(variables)

Verify whether the given variables are compatible with the current model.

calcil.forward.var_list(variables)

Return a list of unique variable identifiers.

calcil.forward.var_replace(variables, vid, value)
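var_list, var_find and var_replace operate on unique variable identifiers. Their exact format is not documented here, so the following pure-Python sketch assumes identifiers are "/"-joined key paths into a nested dict of variables (plain dicts only; a Flax FrozenDict would need to be converted to a plain dict first). The *_sketch functions are hypothetical stand-ins, not calcil's code.

```python
def var_list_sketch(variables, prefix=()):
    """Return "/"-joined key paths to every leaf of a nested dict (assumed format)."""
    ids = []
    for key, value in variables.items():
        path = prefix + (str(key),)
        if isinstance(value, dict):
            ids.extend(var_list_sketch(value, path))
        else:
            ids.append("/".join(path))
    return ids


def var_find_sketch(variables, s):
    """Keyword matching: exactly one identifier must contain the substring s."""
    matches = [vid for vid in var_list_sketch(variables) if s in vid]
    if len(matches) != 1:
        raise KeyError(f"{s!r} matched {len(matches)} variables: {matches}")
    return matches[0]


def var_replace_sketch(variables, vid, value):
    """Return a new nested dict with the leaf at vid replaced; the input is untouched."""
    head, *rest = vid.split("/")
    new = dict(variables)
    if rest:
        new[head] = var_replace_sketch(variables[head], "/".join(rest), value)
    else:
        new[head] = value
    return new
```

Returning a new dict rather than mutating in place matches the immutable style that JAX transformations expect.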

calcil.loss module

Implementation of the Loss class for running gradient-based reconstruction.

class calcil.loss.Loss(loss_fn, name, weight=None, has_intermediates=False)

Bases: object

Loss class that wraps a loss function and its name. It is used to define the loss function of a model. The loss function is a callable that takes the following arguments:

  • forward_output – output of the forward function of the model

  • variables – trainable variables of the model

  • input_dict – input dictionary from the data loader

  • intermediate – intermediate variables of the model (optional)

It returns a tuple (loss, aux_dict), where loss is a scalar value and aux_dict is a dictionary of auxiliary terms used for logging and debugging.

Simple arithmetic operations are defined for the Loss class. For example, if loss_fn1 and loss_fn2 are two Loss objects, then loss_fn1 + loss_fn2 is also a Loss object. The loss function of the new Loss object is the sum of the loss functions of loss_fn1 and loss_fn2. The weights of the two loss functions are also added together. The same applies to multiplication with a scalar.
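The arithmetic described above can be sketched in plain NumPy. SketchLoss and the two loss callables are hypothetical stand-ins (loosely analogous to calcil.loss.get_l2_loss and get_weight_l2_reg, whose exact behaviour is not documented here); each callable follows the documented signature (forward_output, variables, input_dict) -> (loss, aux_dict). The sketch assumes that a sum keeps every term with its own weight and evaluates the weighted total, and that multiplying by a scalar scales all weights; calcil's actual operators may differ.

```python
import numpy as np


def data_l2_loss(forward_output, variables, input_dict):
    # Mean squared mismatch between the forward output and a hypothetical "meas" entry.
    loss = np.mean(np.abs(forward_output - input_dict["meas"]) ** 2)
    return loss, {"data_l2": loss}


def weight_l2_reg(forward_output, variables, input_dict):
    # Squared L2 norm of all (flat) trainable variables.
    loss = sum(np.sum(np.abs(v) ** 2) for v in variables.values())
    return loss, {"weight_l2": loss}


class SketchLoss:
    """Toy analogue of Loss arithmetic; not the library's implementation."""

    def __init__(self, loss_fns, names, weights):
        self.loss_fns, self.names, self.weights = list(loss_fns), list(names), list(weights)

    def __add__(self, other):
        # Keep both sets of terms, each with its own weight.
        return SketchLoss(self.loss_fns + other.loss_fns,
                          self.names + other.names,
                          self.weights + other.weights)

    def __mul__(self, scalar):
        # Scale every term's weight by the scalar.
        return SketchLoss(self.loss_fns, self.names, [w * scalar for w in self.weights])

    __rmul__ = __mul__

    def __call__(self, forward_output, variables, input_dict):
        total, aux = 0.0, {}
        for fn, name, w in zip(self.loss_fns, self.names, self.weights):
            value, terms = fn(forward_output, variables, input_dict)
            total = total + w * value  # weighted sum of all terms
            aux.update(terms)
            aux[name] = value  # unweighted value logged under the term's name
        return total, aux
```

For example, data + 0.1 * reg yields a single object whose call returns the weighted total together with every term for logging.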

Parameters:
  • loss_fn (callable or list of callables) – loss function(s)

  • name (str or list of str) – name(s) of the loss function(s) for logging

  • weight (float or list of float) – weight(s) of the loss function(s)

  • has_intermediates (bool) – whether the loss function needs intermediate variables or not

get_loss_fn()
calcil.loss.get_l2_loss(input_key: str)
calcil.loss.get_weight_l2_reg()
calcil.loss.loss_fn_checker(fn)

Decorator for loss function callables.

calcil.phantom module

Phantom generation utilities.

calcil.phantom.generate_shepp_logan(dim_yx)

calcil.reconstruction module

Module for gradient descent-based reconstruction with any given differentiable forward model.

This module provides functions to perform gradient descent-based reconstruction with any given differentiable forward model. It supports both single-variable reconstruction and multi-variable reconstruction with different optimization settings for each variable, as well as checkpoint saving and TensorBoard logging.

class calcil.reconstruction.ReconIterParameters(save_dir: str, n_epoch: int, keep_checkpoints: int = 1, checkpoint_every: int = 10000, output_every: int = 1000, log_every: int = 100, log_max_imgs: int = 5)

Bases: object

Iterative reconstruction parameters.

Parameters:
  • save_dir (str) – directory to save the reconstruction results.

  • n_epoch (int) – number of epochs.

  • keep_checkpoints (int) – number of checkpoints to keep.

  • checkpoint_every (int) – save a checkpoint every n epochs.

  • output_every (int) – output the reconstruction by calling output_fn every n epochs.

  • log_every (int) – log the loss every n epochs.

  • log_max_imgs (int) – maximum number of images to log each time.

checkpoint_every: int = 10000
keep_checkpoints: int = 1
log_every: int = 100
log_max_imgs: int = 5
n_epoch: int
output_every: int = 1000
save_dir: str
class calcil.reconstruction.ReconVarParameters(lr: float = 0, opt: str | optax.GradientTransformation = 'adam', opt_kwargs: dict = None, schedule: str | optax.Schedule = 'constant_schedule', schedule_kwargs: dict = None, delay_update_n_iter: int = 0, update_every: int = 1)

Bases: object

Optimization parameters for a set of variables.

Parameters:
  • lr (float) – learning rate. Set to 0 to freeze the variable.

  • opt (Union[str, optax.GradientTransformation]) – optimizer. Either a string to specify the optimizer in optax or an optax.GradientTransformation object.

  • opt_kwargs (dict) – a dictionary of keyword arguments for the optimizer.

  • schedule (Union[str, optax.Schedule]) – learning rate schedule. Either a string to specify the schedule in optax or an optax.Schedule object.

  • schedule_kwargs (dict) – a dictionary of keyword arguments for the learning rate schedule.

  • delay_update_n_iter (int) – number of iterations by which to delay updating the variable; the optimizer does nothing for the first n iterations. Note that this counts iterations (steps), not epochs.

  • update_every (int) – update the variable every n iterations. Default is 1 (update every iteration). When set to n, the gradient is accumulated over n iterations before the variable is updated. See optax.apply_every. Note that this counts iterations (steps), not epochs.

delay_update_n_iter: int = 0
lr: float = 0
opt: str | GradientTransformation = 'adam'
opt_kwargs: dict = None
schedule: str | optax.Schedule = 'constant_schedule'
schedule_kwargs: dict = None
update_every: int = 1
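The interaction of delay_update_n_iter and update_every can be illustrated on a scalar with plain SGD. This is a sketch under stated assumptions, not calcil's optax chain: no update happens during the delay, after which gradients are summed (mirroring the accumulation performed by optax.apply_every) and applied once every update_every steps. The helper names is_update_step and run_scalar_sgd are hypothetical.

```python
def is_update_step(step, delay_update_n_iter=0, update_every=1):
    """Assumed gating: nothing during the delay, then one update every update_every steps."""
    if step < delay_update_n_iter:
        return False
    return (step - delay_update_n_iter + 1) % update_every == 0


def run_scalar_sgd(x0, grads, lr, delay_update_n_iter=0, update_every=1):
    """Plain SGD on a scalar; gradients are summed between applied updates."""
    x, acc = x0, 0.0
    for step, g in enumerate(grads):
        if step < delay_update_n_iter:
            continue  # variable is left untouched during the delay
        acc += g
        if is_update_step(step, delay_update_n_iter, update_every):
            x -= lr * acc
            acc = 0.0
    return x
```

Both counters run over iterations (steps), so with several batches per epoch an update_every of n does not mean "every n epochs". Setting lr to 0 leaves the variable frozen regardless of the other settings.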
calcil.reconstruction.generate_nested_dict_keys(d)

Generate a list of keys for a nested dictionary. A key is formatted as levelkey1_levelkey2_..._0.

Parameters:

d (dict) – nested dictionary.

Returns:

list of keys.

Return type:

list

calcil.reconstruction.load_checkpoint_and_output(load_path, output_fn=None)

Load a checkpoint and optionally output the variables or reconstructions.

Parameters:
  • load_path (str) – path to the checkpoint.

  • output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables as input and returns a dictionary of output images.

Returns:

  • dict – reconstructed variables in the checkpoint.

  • dict – output_dict, if output_fn is given.

Return type:

dict

calcil.reconstruction.load_tensorboard_log(log_path, tag, is_image=False)

Load a TensorBoard log from a file.

Parameters:
  • log_path (str) – path to the log file.

  • tag (str) – tag of the TensorBoard variable.

  • is_image (bool) – whether the TensorBoard variable is an image.

Returns:

  • list – TensorBoard variable values.

  • list – iteration numbers corresponding to the TensorBoard variable values.

Return type:

list

calcil.reconstruction.reconstruct_multivars_sgd(forward_fn: Callable, variables: Dict | FrozenDict, var_params_pytree: Dict, data_loader: Generator, loss: Loss, recon_param: ReconIterParameters, output_fn: Callable | None = None, post_update_handler: Callable | None = None, rngs: Dict | None = None, output_info: bool = False)

Reconstruct multiple variables with different optimization settings using SGD.

Parameters:
  • forward_fn (callable) – forward model of the system

  • variables (dict) – initial values of the reconstruction variables

  • var_params_pytree (dict) – nested dictionary of optimization parameters for each variable. All variables must be included in the dictionary; set the learning rate to 0 to freeze a variable. The keys of the dictionary should correspond to the keys of the variables dictionary. Each entry is a ReconVarParameters object (see calcil.reconstruction.ReconVarParameters).

  • data_loader (generator) – generator that yields input dictionary (see calcil.data_utils.loader_from_numpy)

  • loss (Loss) – loss function (see calcil.loss)

  • recon_param (ReconIterParameters) – reconstruction settings (see calcil.reconstruction.ReconIterParameters)

  • output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.

  • post_update_handler (callable) – post update handler that is called after each epoch

  • rngs (dict) – JAX random number generators

  • output_info (bool) – whether to output reconstruction info

Returns:

  • dict – final reconstructed variables at the end of the reconstruction.

  • dict – dictionary of reconstructed images at different iterations.

  • dict – dictionary of reconstruction info at different iterations, if output_info is True.

Return type:

dict
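Because var_params_pytree must cover every variable with a ReconVarParameters entry at the matching key, a quick structural check can catch mistakes before a long run. The sketch below is hypothetical (check_var_params, frozen_leaves and the VarParams stand-in are not calcil names) and assumes leaf-to-leaf matching on plain nested dicts; it also lists the variables frozen by a learning rate of 0.

```python
from dataclasses import dataclass


@dataclass
class VarParams:
    """Hypothetical stand-in for calcil.reconstruction.ReconVarParameters."""
    lr: float = 0.0
    opt: str = "adam"


def check_var_params(variables, var_params_pytree, path=""):
    """Raise if a variable leaf has no VarParams entry at the same key path."""
    for key, value in variables.items():
        where = f"{path}/{key}" if path else str(key)
        if key not in var_params_pytree:
            raise KeyError(f"no optimization parameters for {where!r}")
        spec = var_params_pytree[key]
        if isinstance(value, dict):
            if not isinstance(spec, dict):
                raise TypeError(f"expected a nested dict of parameters at {where!r}")
            check_var_params(value, spec, where)
        elif not isinstance(spec, VarParams):
            raise TypeError(f"expected a VarParams entry at {where!r}")


def frozen_leaves(var_params_pytree, path=""):
    """List key paths whose learning rate is 0, i.e. frozen variables."""
    out = []
    for key, spec in var_params_pytree.items():
        where = f"{path}/{key}" if path else str(key)
        if isinstance(spec, dict):
            out.extend(frozen_leaves(spec, where))
        elif spec.lr == 0:
            out.append(where)
    return out
```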

calcil.reconstruction.reconstruct_sgd(forward_fn: Callable, variables: Dict | FrozenDict, data_loader: Generator, loss: Loss, var_params: ReconVarParameters, recon_param: ReconIterParameters, output_fn: Callable | None = None, post_update_handler: Callable | None = None, rngs: Dict | None = None, output_info: bool = False)

Reconstruct variables using a single optimization setting with SGD.

Parameters:
  • forward_fn (callable) – forward model of the system.

  • variables (dict) – initial values of the reconstruction variables.

  • data_loader (generator) – generator that yields input dictionary (see calcil.data_utils.loader_from_numpy).

  • loss (Loss) – loss function (see calcil.loss).

  • var_params (ReconVarParameters) – optimization parameters (see calcil.reconstruction.ReconVarParameters).

  • recon_param (ReconIterParameters) – reconstruction settings (see calcil.reconstruction.ReconIterParameters).

  • output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.

  • post_update_handler (callable) – post update handler that is called after each epoch.

  • rngs (dict) – random number generators.

  • output_info (bool) – whether to output reconstruction info.

Returns:

  • dict – final reconstructed variables at the end of the reconstruction.

  • dict – dictionary of output images generated by output_fn at different iterations.

  • dict – dictionary of reconstruction info at different iterations, if output_info is True.

Return type:

dict

calcil.reconstruction.run_reconstruction(state: TrainState, data_loader: Generator, loss: Loss, recon_param: ReconIterParameters, output_fn: Callable | None, post_update_handler: Callable, rngs: Dict | None)

Run SGD reconstruction with a flax.training.train_state.TrainState.

Parameters:
  • state (flax.training.train_state.TrainState) – initial state of the reconstruction

  • data_loader (generator) – generator that yields input dictionary

  • loss (Loss) – loss function

  • recon_param (ReconIterParameters) – reconstruction settings

  • output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.

  • post_update_handler (callable) – post update handler

  • rngs (dict) – random number generators

Returns:

  • dict – final reconstructed variables at the end of the reconstruction.

  • dict – dictionary of reconstructed images at different iterations.

  • dict – dictionary of reconstruction info at different iterations.

Return type:

dict

calcil.reconstruction.update_iter_sgd(state: TrainState, input_dict, rngs, loss_fn)

Update function for SGD reconstruction.

Parameters:
  • state (flax.training.train_state.TrainState) – current state of the reconstruction.

  • input_dict (dict) – input dictionary.

  • rngs (dict) – dict of random number generators for potential randomness in the forward model.

  • loss_fn (callable) – loss function.

Returns:

  • TrainState – new state of the reconstruction.

  • dict – reconstruction info.

Return type:

flax.training.train_state.TrainState
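To make the split between a single update step (update_iter_sgd) and the outer loop (run_reconstruction) concrete, here is a toy NumPy sketch. It swaps JAX autodiff and optax for a hand-derived gradient of a mean L2 loss on a linear forward model y = A @ x, and ToyState is a hypothetical stand-in for flax.training.train_state.TrainState holding only a step counter and the parameters. The names toy_update_iter and toy_run are illustrative, not calcil functions, and the logging rule (every log_every steps) is an assumption.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyState:
    """Minimal immutable training state: step counter plus parameters."""
    step: int
    params: np.ndarray


def toy_update_iter(state, input_dict, lr):
    """One SGD step for y = A @ x with loss mean((A @ x - y) ** 2)."""
    A, y = input_dict["A"], input_dict["y"]
    residual = A @ state.params - y
    loss = np.mean(residual ** 2)  # loss at the pre-update parameters
    grad = 2.0 * A.T @ residual / residual.size  # hand-derived gradient of the mean L2 loss
    new_state = replace(state, step=state.step + 1, params=state.params - lr * grad)
    return new_state, {"loss": loss}


def toy_run(state, data_loader, lr, n_iter, log_every=10):
    """Outer loop: pull a batch, update, and record the loss every log_every steps."""
    log = {}
    for _ in range(n_iter):
        state, info = toy_update_iter(state, next(data_loader), lr)
        if state.step % log_every == 0:
            log[state.step] = info["loss"]
    return state, log
```

Keeping the step function pure (state in, new state and info out) is what allows the real update to be wrapped in jax.jit, while the loop handles logging, output and checkpointing.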

calcil.signal module

Signal processing utilities.

calcil.signal.fftconvolve(in1, in2, mode='full', axes=None)

Module contents