calcil package
Subpackages
Submodules
calcil.data_utils module
Implementation of data utilities for loading data from numpy arrays and tf datasets.
- calcil.data_utils.loader_from_numpy(input_dict, prefix_dim=None, random=True, seed=951, aux_terms=None, nojax=True)
Load data from numpy arrays.
- Parameters:
input_dict (dict) – dictionary of numpy arrays. The first dimension is used as the batch dimension. All arrays should have the same length in the first dimension.
prefix_dim (tuple) – prefix dimension of the output arrays. It is used to control batch size in practice. If None, the first dimension is used.
random (bool) – whether to sample randomly. If False, the data is loaded in order.
seed (int) – random seed
aux_terms (dict) – auxiliary terms to be added to the output dictionary. The same terms are added to all batches.
nojax (bool) – if True, do not use jax.
- Returns:
a generator that yields a dictionary of numpy arrays.
- Return type:
generator
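The batching behaviour described above can be sketched in plain NumPy. This is a hypothetical stand-in, not calcil's code. The name iterate_minibatches and its batch_size argument are assumptions: calcil controls batch size through prefix_dim, whose exact reshaping rule is not documented here. The sketch also stops after one pass over the data, and the text above does not say whether loader_from_numpy does the same.

```python
import numpy as np


def iterate_minibatches(input_dict, batch_size, random=True, seed=951, aux_terms=None):
    """Yield dictionaries of NumPy batches for one pass over the data (hypothetical sketch).

    As in loader_from_numpy, the first axis of every array is the batch axis;
    batch_size is an assumed stand-in for prefix_dim.
    """
    lengths = {v.shape[0] for v in input_dict.values()}
    if len(lengths) != 1:
        raise ValueError("all arrays must share the same first dimension")
    n = lengths.pop()
    # Shuffle indices when random=True, otherwise keep the original order.
    order = np.random.default_rng(seed).permutation(n) if random else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        batch = {k: v[idx] for k, v in input_dict.items()}
        # The same auxiliary terms are attached to every batch.
        if aux_terms:
            batch.update(aux_terms)
        yield batch
```

With ten samples and batch_size=4, the sketch yields batches of 4, 4 and 2 samples, each carrying the same aux_terms entries.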
- calcil.data_utils.tfds_files_loader(tf_dataset)
calcil.dataclasses module
Utilities for defining custom classes that can be used with jax transformations.
Copyright 2022 The Flax Authors.
Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
- calcil.dataclasses.dataclass(clz: _T) → _T
Create a class which can be passed to functional transformations.
NOTE: Inherit from PyTreeNode instead to avoid type checking issues when using PyType.
Jax transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to Jax. For example:

```python
from types import FunctionType
from typing import Any

import jax
from flax import struct

@struct.dataclass
class Model:
  params: Any
  # use pytree_node=False to indicate an attribute should not be touched
  # by Jax transformations.
  apply_fn: FunctionType = struct.field(pytree_node=False)

  def __apply__(self, *args):
    return self.apply_fn(*args)

model = Model(params, apply_fn)

model.params = params_b  # Model is immutable. This will raise an error.
model_b = model.replace(params=params_b)  # Use the replace method instead.

# This class can now be used safely in Jax to compute gradients w.r.t. the
# parameters.
model = Model(params, apply_fn)
model_grad = jax.grad(some_loss_fn)(model)
```
Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. This correspondence is what makes these objects valid containers that work with JAX transformations and, more generally, with the jax.tree_util library.
Sometimes a “smart constructor” is desired, for example because some of the attributes can be (optionally) derived from others. The way to do this with Flax dataclasses is to make a static or class method that provides the smart constructor. This way the simple constructor used by jax.tree_util is preserved. Consider the following example:

```python
import jax
from flax import struct
from jax import Array

@struct.dataclass
class DirectionAndScaleKernel:
  direction: Array
  scale: Array

  @classmethod
  def create(cls, kernel):
    scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
    direction = kernel / scale
    return cls(direction, scale)
```
- Parameters:
clz – the class that will be transformed by the decorator.
- Returns:
The new class.
- calcil.dataclasses.get_dataclass_instances(serialize=False)
- calcil.dataclasses.register_dataclass_instance(ins)
- calcil.dataclasses.reset_dataclass_instances()
calcil.forward module
Implementation of forward model class for image reconstruction.
- class calcil.forward.Model(*args, **kwargs)
Bases: Module
- forward(*input_args, variables=None, method_name=None, rngs=None)
- log_intermediate()
- model_hyperparams()
Serialize all model parameters and return them in serialized form (e.g. JSON).
- model_save()
- name: str = None
- parent: Type[Module] | Type[Scope] | Type[_Sentinel] | None = <flax.linen.module._Sentinel object>
- classmethod recover_from_hyperparams(s)
- scope = None
- var(name, mode='update', init_fn=None, shape=None, dtype=None)
Wrapper around the Flax Module methods param and variable.
- var_find(variables, s)
Find a variable by keyword matching and return its unique identifier.
- var_verify(variables)
Verify whether the given variables are compatible with the current model.
- calcil.forward.var_list(variables)
Return a list of the unique identifiers of the variables.
- calcil.forward.var_replace(variables, vid, value)
calcil.loss module
Implementation of Loss class for running gradient-based reconstruction.
- class calcil.loss.Loss(loss_fn, name, weight=None, has_intermediates=False)
Bases: object
Loss class that wraps a loss function and its name. It is used to define the loss function of a model. The loss function is a callable that takes the following arguments:
forward_output – output of the forward function of the model
variables – trainable variables of the model
input_dict – input dictionary from the data loader
intermediate – intermediate variables of the model (optional)
It returns a tuple (loss, aux_dict), where loss is a scalar value and aux_dict is a dictionary of auxiliary terms used for logging and debugging.
Simple arithmetic operations are defined for the Loss class. For example, if loss_fn1 and loss_fn2 are two Loss objects, then loss_fn1 + loss_fn2 is also a Loss object. The loss function of the new Loss object is the sum of the loss functions of loss_fn1 and loss_fn2. The weights of the two loss functions are also added together. The same applies to multiplication with a scalar.
- Parameters:
loss_fn (callable or list of callables) – loss function(s)
name (str or list of str) – name(s) of the loss function(s) for logging
weight (float or list of float) – weight(s) of the loss function(s)
has_intermediates (bool) – whether the loss function needs intermediate variables or not
- get_loss_fn()
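A loss callable that follows the argument order described above might look like the sketch below. The input_dict key 'meas' and the helpers l2_data_fidelity and combine are hypothetical. combine only mimics the weighted-sum arithmetic described for Loss objects; it is not calcil's implementation. With calcil installed, the callable would presumably be wrapped as Loss(l2_data_fidelity, 'l2'), following the constructor signature above, but that call is not exercised here.

```python
import jax.numpy as jnp


def l2_data_fidelity(forward_output, variables, input_dict):
    """Mean squared error between the forward-model output and a measurement.

    'meas' is a hypothetical input_dict key chosen for this sketch.
    Returns (loss, aux_dict), matching the convention described for Loss.
    """
    residual = forward_output - input_dict["meas"]
    loss = jnp.mean(jnp.abs(residual) ** 2)
    # aux_dict only carries extra quantities for logging/debugging.
    aux_dict = {"max_abs_residual": jnp.max(jnp.abs(residual))}
    return loss, aux_dict


def combine(loss_fns, weights):
    """Toy weighted sum of loss callables (illustrates the arithmetic, not calcil's code)."""
    def total(forward_output, variables, input_dict):
        total_loss, aux = 0.0, {}
        for i, (fn, w) in enumerate(zip(loss_fns, weights)):
            term, term_aux = fn(forward_output, variables, input_dict)
            total_loss = total_loss + w * term
            # Prefix aux keys with the term index so entries do not collide.
            aux.update({f"{i}_{k}": v for k, v in term_aux.items()})
        return total_loss, aux
    return total
```

Because both functions use only jax.numpy, the resulting scalar loss can be differentiated with jax.grad with respect to forward_output or, through a forward model, the variables.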
- calcil.loss.get_l2_loss(input_key: str)
- calcil.loss.get_weight_l2_reg()
- calcil.loss.loss_fn_checker(fn)
Decorator for loss function callables.
calcil.phantom module
Phantom generation utilities.
- calcil.phantom.generate_shepp_logan(dim_yx)
calcil.reconstruction module
Module for gradient descent-based reconstruction with any given differentiable forward model.
This module provides functions to perform gradient descent-based reconstruction with any given differentiable forward model. The module supports single variable reconstruction and multi-variable reconstruction with different optimization settings for each variable. The module also supports checkpoint saving and tensorboard logging.
- class calcil.reconstruction.ReconIterParameters(save_dir: str, n_epoch: int, keep_checkpoints: int = 1, checkpoint_every: int = 10000, output_every: int = 1000, log_every: int = 100, log_max_imgs: int = 5)
Bases: object
Iterative reconstruction parameters.
- Parameters:
save_dir (str) – directory to save the reconstruction results.
n_epoch (int) – number of epochs.
keep_checkpoints (int) – number of checkpoints to keep.
checkpoint_every (int) – save checkpoint every n epochs.
output_every (int) – output reconstruction by calling output_fn every n epochs.
log_every (int) – log loss every n epochs.
log_max_imgs (int) – maximum number of images to log each time.
- checkpoint_every: int = 10000
- keep_checkpoints: int = 1
- log_every: int = 100
- log_max_imgs: int = 5
- n_epoch: int
- output_every: int = 1000
- save_dir: str
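The sketch below shows how the *_every settings interact. It mirrors the documented fields in a local stand-in dataclass and lists the epochs at which each action would fire. Both IterParamsStandIn and scheduled_epochs are hypothetical. The rule that an action fires when epoch % every == 0, with epochs counted from 0, is an assumption, and calcil's actual trigger condition may differ.

```python
from dataclasses import dataclass


@dataclass
class IterParamsStandIn:
    # Same fields and defaults as ReconIterParameters (stand-in, not the real class).
    save_dir: str
    n_epoch: int
    keep_checkpoints: int = 1
    checkpoint_every: int = 10000
    output_every: int = 1000
    log_every: int = 100
    log_max_imgs: int = 5


def scheduled_epochs(p):
    """Epochs at which logging, output and checkpointing would happen (assumed rule)."""
    epochs = range(p.n_epoch)
    return {
        "log": [e for e in epochs if e % p.log_every == 0],
        "output": [e for e in epochs if e % p.output_every == 0],
        "checkpoint": [e for e in epochs if e % p.checkpoint_every == 0],
    }
```

Under this assumed rule, a 300-epoch run with the default checkpoint_every=10000 would checkpoint only at epoch 0. For short runs, checkpoint_every should therefore be set well below n_epoch.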
- class calcil.reconstruction.ReconVarParameters(lr: float = 0, opt: str | optax.GradientTransformation = 'adam', opt_kwargs: dict = None, schedule: str | optax.Schedule = 'constant_schedule', schedule_kwargs: dict = None, delay_update_n_iter: int = 0, update_every: int = 1)
Bases: object
Optimization parameters for a set of variables.
- Parameters:
lr (float) – learning rate. Set to 0 to freeze the variable.
opt (Union[str, optax.GradientTransformation]) – optimizer. Either a string to specify the optimizer in optax or an optax.GradientTransformation object.
opt_kwargs (dict) – a dictionary of keyword arguments for the optimizer.
schedule (Union[str, optax.Schedule]) – learning rate schedule. Either a string to specify the schedule in optax or an optax.Schedule object.
schedule_kwargs (dict) – a dictionary of keyword arguments for the learning rate schedule.
delay_update_n_iter (int) – number of iterations to delay the update of the variable. The optimization will do nothing for the first n iterations. Note that this goes with iteration (step), not epoch.
update_every (int) – update the variable every n iterations. Default is 1 (update every iteration). When set to n, the gradient will be accumulated for n iterations before updating the variable. See: optax.apply_every. Note that this goes with iteration (step), not epoch.
- delay_update_n_iter: int = 0
- lr: float = 0
- opt: str | GradientTransformation = 'adam'
- opt_kwargs: dict = None
- schedule: str | optax.Schedule = 'constant_schedule'
- schedule_kwargs: dict = None
- update_every: int = 1
- calcil.reconstruction.generate_nested_dict_keys(d)
Generate a list of keys for a nested dictionary. A key is formatted as levelkey1_levelkey2_..._0.
- Parameters:
d (dict) – nested dictionary.
- Returns:
list of keys.
- Return type:
list
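Below is a recursive sketch of the key flattening. It assumes keys are joined with underscores from the outermost level inward. The function nested_dict_keys is hypothetical, and it omits the trailing _0 suffix shown in the format above because that suffix's meaning is not documented. It checks for collections.abc.Mapping rather than dict so that Flax FrozenDict inputs are also traversed.

```python
from collections.abc import Mapping


def nested_dict_keys(d, prefix=""):
    """Flatten nested-dict keys into 'level1_level2_...' strings (hypothetical sketch)."""
    keys = []
    for k, v in d.items():
        name = f"{prefix}_{k}" if prefix else str(k)
        if isinstance(v, Mapping):
            # Descend into sub-dictionaries, carrying the joined prefix along.
            keys.extend(nested_dict_keys(v, name))
        else:
            keys.append(name)
    return keys
```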
- calcil.reconstruction.load_checkpoint_and_output(load_path, output_fn=None)
Load a checkpoint and optionally output the variables or reconstructions.
- Parameters:
load_path (str) – path to the checkpoint.
output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables as input and returns a dictionary of output images.
- Returns:
reconstructed variables in the checkpoint.
dict: output_dict, if output_fn is given.
- Return type:
dict
- calcil.reconstruction.load_tensorboard_log(log_path, tag, is_image=False)
Load tensorboard log from a file.
- Parameters:
log_path (str) – path to the log file.
tag (str) – tag of the tensorboard variable.
is_image (bool) – whether the tensorboard variable is an image.
- Returns:
list of tensorboard variable values.
list: list of iteration numbers corresponding to the tensorboard variable values.
- Return type:
list
- calcil.reconstruction.reconstruct_multivars_sgd(forward_fn: Callable, variables: Dict | FrozenDict, var_params_pytree: Dict, data_loader: Generator, loss: Loss, recon_param: ReconIterParameters, output_fn: Callable | None = None, post_update_handler: Callable | None = None, rngs: Dict | None = None, output_info: bool = False)
Reconstruct multiple variables with different optimization settings using SGD.
- Parameters:
forward_fn (callable) – forward model of the system
variables (dict) – initial values of the reconstruction variables
var_params_pytree (dict) – nested dictionary of optimization parameters for each variable. All variables must be included in the dictionary. Set the learning rate to 0 to freeze a variable. The keys of the dictionary should correspond to the keys of the variables dictionary. Each entry is a ReconVarParameters object (see calcil.reconstruction.ReconVarParameters).
data_loader (generator) – generator that yields input dictionary (see calcil.data_utils.loader_from_numpy)
loss (Loss) – loss function (see calcil.loss)
recon_param (ReconIterParameters) – reconstruction settings (see calcil.reconstruction.ReconIterParameters)
output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.
post_update_handler (callable) – post update handler that is called after each epoch
rngs (dict) – jax random number generators
output_info (bool) – whether to output reconstruction info
- Returns:
final reconstructed variables at the end of the reconstruction.
dict: dictionary of reconstructed images at different iterations.
dict: dictionary of reconstruction info at different iterations, if output_info is True.
- Return type:
dict
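The requirement that var_params_pytree mirror the variables dictionary can be checked with a short recursive walk. In the sketch below, VarParamsStandIn is a hypothetical placeholder for ReconVarParameters that keeps only lr. check_var_params is a hypothetical helper, and the 'params'/'obj'/'probe' nesting is purely illustrative.

```python
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass
class VarParamsStandIn:
    # Hypothetical stand-in for ReconVarParameters (only the learning rate is kept).
    lr: float = 0.0


def check_var_params(variables, var_params):
    """Raise if var_params does not mirror the nested keys of variables (sketch)."""
    if isinstance(variables, Mapping):
        if not isinstance(var_params, Mapping) or set(variables) != set(var_params):
            raise ValueError(f"key mismatch: {sorted(variables)} vs {var_params!r}")
        for k in variables:
            check_var_params(variables[k], var_params[k])
    elif not isinstance(var_params, VarParamsStandIn):
        # At a leaf array there must be exactly one parameter object.
        raise TypeError(f"expected a parameter object at a leaf, got {var_params!r}")
```

In a real call, the leaves would be ReconVarParameters objects. For example, lr=0 on the 'probe' entry would keep that variable fixed while 'obj' is optimized.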
- calcil.reconstruction.reconstruct_sgd(forward_fn: Callable, variables: Dict | FrozenDict, data_loader: Generator, loss: Loss, var_params: ReconVarParameters, recon_param: ReconIterParameters, output_fn: Callable | None = None, post_update_handler: Callable | None = None, rngs: Dict | None = None, output_info: bool = False)
Reconstruct variables using a single optimization setting with SGD.
- Parameters:
forward_fn (callable) – forward model of the system.
variables (dict) – initial values of the reconstruction variables.
data_loader (generator) – generator that yields input dictionary (see calcil.data_utils.loader_from_numpy).
loss (Loss) – loss function (see calcil.loss).
var_params (ReconVarParameters) – optimization parameters (see calcil.reconstruction.ReconVarParameters).
recon_param (ReconIterParameters) – reconstruction settings (see calcil.reconstruction.ReconIterParameters).
output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.
post_update_handler (callable) – post update handler that is called after each epoch.
rngs (dict) – random number generators.
output_info (bool) – whether to output reconstruction info.
- Returns:
final reconstructed variables at the end of the reconstruction.
dict: dictionary of output images generated by output_fn at different iterations.
dict: dictionary of reconstruction info at different iterations, if output_info is True.
- Return type:
dict
- calcil.reconstruction.run_reconstruction(state: TrainState, data_loader: Generator, loss: Loss, recon_param: ReconIterParameters, output_fn: Callable | None, post_update_handler: Callable, rngs: Dict | None)
Run SGD reconstruction with flax.training.train_state.TrainState.
- Parameters:
state (flax.training.train_state.TrainState) – initial state of the reconstruction
data_loader (generator) – generator that yields input dictionary
loss (Loss) – loss function
recon_param (ReconIterParameters) – reconstruction settings
output_fn (callable) – output function that renders output images. output_fn takes the reconstruction variables and train state (from which forward_fn can be called with the variables) as input and returns a dictionary of output images.
post_update_handler (callable) – post update handler
rngs (dict) – random number generators
- Returns:
final reconstructed variables at the end of the reconstruction.
dict: dictionary of reconstructed images at different iterations.
dict: dictionary of reconstruction info at different iterations.
- Return type:
dict
- calcil.reconstruction.update_iter_sgd(state: TrainState, input_dict, rngs, loss_fn)
Update function for SGD reconstruction.
- Parameters:
state (flax.training.train_state.TrainState) – current state of the reconstruction.
input_dict (dict) – input dictionary.
rngs (dict) – dict of random number generators for potential randomness in the forward model.
loss_fn (callable) – loss function.
- Returns:
new state of the reconstruction.
dict: reconstruction info.
- Return type:
flax.training.train_state.TrainState
calcil.signal module
Signal processing utilities.
- calcil.signal.fftconvolve(in1, in2, mode='full', axes=None)
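The signature of calcil.signal.fftconvolve matches scipy.signal.fftconvolve(in1, in2, mode='full', axes=None), which suggests the same semantics, though this is not stated above. The NumPy sketch below shows the idea behind mode='full' over all axes. Both inputs are zero-padded to s1 + s2 - 1 samples per axis, their FFTs are multiplied, and the product is inverted. Only 'full' mode, real inputs of equal dimensionality, and convolution over all axes are handled. The name fftconvolve_full is hypothetical.

```python
import numpy as np


def fftconvolve_full(in1, in2):
    """Full linear convolution of two real N-D arrays via the FFT (hypothetical sketch)."""
    in1, in2 = np.asarray(in1, float), np.asarray(in2, float)
    # Output size per axis in 'full' mode; padding to it avoids circular wrap-around.
    shape = [a + b - 1 for a, b in zip(in1.shape, in2.shape)]
    axes = list(range(in1.ndim))
    spec = np.fft.rfftn(in1, s=shape, axes=axes) * np.fft.rfftn(in2, s=shape, axes=axes)
    return np.fft.irfftn(spec, s=shape, axes=axes)
```

The padding step is what turns the FFT's circular convolution into a linear one. Without it, the tails of the result would wrap around and corrupt the leading samples.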